Skip to content

Commit

Permalink
fixed wrapper class and add rules for a few more methods
Browse files Browse the repository at this point in the history
  • Loading branch information
e-marshall committed Aug 12, 2024
1 parent 0bd1932 commit 619e043
Showing 1 changed file with 79 additions and 58 deletions.
137 changes: 79 additions & 58 deletions tests/v3/test_store/test_stateful_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,73 +11,77 @@
rule,
)

from zarr.abc.store import Store as StoreABC
from zarr.buffer import Buffer, default_buffer_prototype
from zarr.store import LocalStore, MemoryStore, RemoteStore
from zarr.store import MemoryStore

# from strategies_store import StoreStatefulStrategies, key_ranges
from zarr.testing.strategies import key_ranges, paths

# zarr spec: https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html


class SyncStoreWrapper():
class SyncStoreWrapper:
def __init__(self, store):
"""Store to hold async functions that map to StoreABC abstract methods"""
"""Class to hold sync functions that map to async methods of MemoryStore
MemoryStore methods are async, this class' methods are sync, so just need to call asyncio.run() in them
then, methods in statemachine class are sync and call sync
"""
self.store = store
# Unfortunately, hypothesis' stateful testing infra does not support asyncio
self.mode = store.mode

# Unfortunately, hypothesis' stateful testing infra does not support asyncio
# So we redefine sync versions of the Store API.
# https://github.com/HypothesisWorks/hypothesis/issues/3712#issuecomment-1668999041
def set(self, key, data_buffer): # buffer is value
return asyncio.run(self.store.set(key, data_buffer))

async def list(self):
paths = [path async for path in self.store.list()]
# note(f'(store set) store {paths=}, type {type(paths)=}, {len(paths)=}')
return paths
def list(self):
async def wrapper(gen):
return [i async for i in gen]

gen = self.store.list() # async store list
return (i for i in asyncio.run(wrapper(gen)))

async def get(
self,
key,
):
obs = await self.store.get(key, prototype=default_buffer_prototype())
def get(self, key):
obs = asyncio.run(self.store.get(key, prototype=default_buffer_prototype()))
return obs

async def get_partial_values(self, key_ranges):
obs_maybe = await self.store.get_partial_values(
prototype=default_buffer_prototype(), key_ranges=key_ranges
def get_partial_values(self, key_ranges):
obs_partial = asyncio.run(
self.store.get_partial_values(
prototype=default_buffer_prototype(), key_ranges=key_ranges
)
)
return obs_maybe
return obs_partial

async def delete(self, path): # path is key
await self.store.delete(path)
def delete(self, path): # path is key
return asyncio.run(self.store.delete(path))

async def empty(self):
await self.store.empty()
def empty(self):
return asyncio.run(self.store.empty())

async def clear(self):
await self.store.clear()
def clear(self):
return asyncio.run(self.store.clear())

async def exists(self, key):
raise NotImplementedError
def exists(self, key):
return asyncio.run(self.store.exists(key))

async def list_dir(self, prefix):
def list_dir(self, prefix):
raise NotImplementedError

async def list_prefix(self, prefix: str):
def list_prefix(self, prefix: str):
raise NotImplementedError

async def set_partial_values(self, key_start_values):
def set_partial_values(self, key_start_values):
raise NotImplementedError

async def supports_listing(self):
def supports_listing(self):
raise NotImplementedError

async def supports_partial_writes(self):
def supports_partial_writes(self):
raise NotImplementedError

async def supports_writes(self):
def supports_writes(self):
raise NotImplementedError


Expand All @@ -86,44 +90,52 @@ def __init__(self): # look into using run_machine_as_test()
super().__init__()
self.model = {}
self.store = SyncStoreWrapper(MemoryStore(mode="w"))
# self.store = MemoryStore(mode='w')

@rule(key=paths, data=st.binary(min_size=0, max_size=100))
def set(self, key: str, data: bytes) -> None:
note(f"(set) Setting {key!r} with {data}")
assert not self.store.mode.readonly
data_buf = Buffer.from_bytes(data)
self.store.set(key, data_buf))
self.store.set(key, data_buf)
self.model[key] = data_buf # this was data

@invariant()
def check_paths_equal(self) -> None:
note("Checking that paths are equal")
paths = self.store.list()
paths = list(self.store.list())

assert list(self.model.keys()) == paths

@invariant()
def check_vals_equal(self) -> None:
note("Checking values equal")
for key, _val in self.model.items():
store_item = self.store.get(key).to_bytes()
# note(f'(inv) model item: {self.model[key]}')
# note(f'(inv) {store_item=}')
assert self.model[key].to_bytes() == store_item

@invariant()
def check_num_keys_equal(self) -> None:
note("check num keys equal")
model_keys_len = len(list(self.model.keys()))
store_keys_len = len(list(self.store.list()))
assert model_keys_len == store_keys_len

# @rule(key=keys_bundle)
@precondition(lambda self: len(self.model.keys()) > 0)
@rule(data=st.data())
# @rule(key=keys_bundle)
def get(self, data) -> None:
key = data.draw(st.sampled_from(sorted(self.model.keys())))
store_value = asyncio.run(self.sync_wrapper.get(key))
store_value = self.store.get(key)
# to bytes here necessary (on model and store) because data_buf set to model in set()
assert self.model[key].to_bytes() == store_value.to_bytes()

@precondition(lambda self: len(self.model.keys()) > 0)
@rule(data=st.data())
def get_partial_values(self, data) -> None:
key_st = st.sampled_from(sorted(self.model.keys()))
key_st = st.sampled_from(sorted(self.model.keys())) # hypothesis wants you to sort
key_range = data.draw(key_ranges(keys=key_st))

obs_maybe = asyncio.run(self.sync_wrapper.get_partial_values(key_range))
obs_maybe = self.store.get_partial_values(key_range)
observed = []

for obs in obs_maybe:
Expand Down Expand Up @@ -152,31 +164,40 @@ def delete(self, data) -> None:
path_st = data.draw(st.sampled_from(sorted(self.model.keys())))
note(f"(delete) Deleting {path_st=}")

asyncio.run(self.sync_wrapper.delete(path_st))
self.store.delete(path_st)
del self.model[path_st]

@rule(key=paths, data=st.binary(min_size=0, max_size=100))
def clear(self, key: str, data: bytes):
"""clear() is in zarr/store/memory.py
it calls clear on self._store_dict
clear() is dict method that removes all key-val pairs from dict
"""
@rule()
def clear(self):
assert not self.store.mode.readonly
note("(clear)")
asyncio.run(self.sync_wrapper.clear())
self.store.clear()
self.model.clear()

# check that model was cleared
assert len(self.model.keys()) == 0

# @rule()
# def empty(self, data) -> None:
# """empty checks if a store is empty or not
# return true if self._store_dict doesn't exist
# return false if self._store_dict exists"""
# note("(empty)")
@rule()
def empty(self) -> None:
note("(empty)")
# check if store, model are empty
store_empty = self.store.empty()
model_empty = not self.model
# make sure they either both are or both aren't (same state)
assert model_empty == store_empty

@rule(key=paths)
def exists(self, key) -> None:
note("(exists)")

def model_exists(self, key) -> bool:
return key in self.model

# check if given key in model, store
store_exists = self.store.exists(key)
model_exists = model_exists(self, key)

# asyncio.run(self.sync_wrapper.empty())
# assert self.store.empty()
# make sure same state
assert model_exists == store_exists


StatefulStoreTest = ZarrStoreStateMachine.TestCase

0 comments on commit 619e043

Please sign in to comment.