diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py new file mode 100644 index 000000000..792dc66d9 --- /dev/null +++ b/src/zarr/store/logging.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import inspect +import logging +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import TYPE_CHECKING + +from zarr.abc.store import AccessMode, Store + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator + + from zarr.core.buffer import Buffer, BufferPrototype + + +class LoggingStore(Store): + _store: Store + counter: defaultdict[str, int] + + def __init__( + self, + store: Store, + log_level: str = "DEBUG", + log_handler: logging.Handler | None = None, + ): + self._store = store + self.counter = defaultdict(int) + + self._configure_logger(log_level, log_handler) + + def _configure_logger( + self, log_level: str = "DEBUG", log_handler: logging.Handler | None = None + ) -> None: + self.log_level = log_level + self.logger = logging.getLogger(f"LoggingStore({self._store!s})") + self.logger.setLevel(log_level) + + if not self.logger.hasHandlers(): + if not log_handler: + log_handler = self._default_handler() + # Add handler to logger + self.logger.addHandler(log_handler) + + def _default_handler(self) -> logging.Handler: + """Define a default log handler""" + handler = logging.StreamHandler() + handler.setLevel(self.log_level) + handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) + return handler + + @contextmanager + def log(self) -> Generator[None, None, None]: + method = inspect.stack()[2].function + op = f"{type(self._store).__name__}.{method}" + self.logger.info(f"Calling {op}") + start_time = time.time() + try: + self.counter[method] += 1 + yield + finally: + end_time = time.time() + self.logger.info(f"Finished {op} in {end_time - start_time:.2f} seconds") + + @property + def supports_writes(self) -> bool: + with self.log(): + return self._store.supports_writes + + @property + def supports_deletes(self) -> bool: + with self.log(): + return self._store.supports_deletes + + @property + def supports_partial_writes(self) -> bool: + with self.log(): + return self._store.supports_partial_writes + + @property + def supports_listing(self) -> bool: + with self.log(): + return self._store.supports_listing + + @property + def _mode(self) -> AccessMode: # type: ignore[override] + with self.log(): + return self._store._mode + + @property + def _is_open(self) -> bool: # type: ignore[override] + with self.log(): + return self._store._is_open + + async def empty(self) -> bool: + with self.log(): + return await self._store.empty() + + async def clear(self) -> None: + with self.log(): + return await self._store.clear() + + def __str__(self) -> str: + return f"logging-{self._store!s}" + + def __repr__(self) -> str: + return f"LoggingStore({repr(self._store)!r})" + + def __eq__(self, other: object) -> bool: + with self.log(): + return self._store == other + + async def get( + self, + key: str, + prototype: BufferPrototype, + byte_range: tuple[int | None, int | None] | None = None, + ) -> Buffer | None: + with self.log(): + return await self._store.get(key=key, prototype=prototype, byte_range=byte_range) + + async def get_partial_values( + self, + prototype: BufferPrototype, + key_ranges: list[tuple[str, tuple[int | None, int | None]]], + ) -> list[Buffer | None]: + with self.log(): + return await self._store.get_partial_values(prototype=prototype, key_ranges=key_ranges) + + async def exists(self, key: str) -> bool: + with self.log(): + return await self._store.exists(key) + + async def set(self, key: str, value: Buffer) -> None: + with self.log(): + return await self._store.set(key=key, value=value) + + async def delete(self, key: str) -> None: + with self.log(): + return await self._store.delete(key=key) + + async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + with self.log(): + return await self._store.set_partial_values(key_start_values=key_start_values) + + async def list(self) -> AsyncGenerator[str, None]: + with self.log(): + async for key in self._store.list(): + yield key + + async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + with self.log(): + async for key in self._store.list_prefix(prefix=prefix): + yield key + + async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: + with self.log(): + async for key in self._store.list_dir(prefix=prefix): + yield key diff --git a/tests/v3/test_store/test_logging.py b/tests/v3/test_store/test_logging.py new file mode 100644 index 000000000..a263c2ae0 --- /dev/null +++ b/tests/v3/test_store/test_logging.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import zarr +from zarr.core.buffer import default_buffer_prototype +from zarr.store.logging import LoggingStore + +if TYPE_CHECKING: + from zarr.abc.store import Store + + +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +async def test_logging_store(store: Store, caplog) -> None: + wrapped = LoggingStore(store=store, log_level="DEBUG") + buffer = default_buffer_prototype().buffer + + caplog.clear() + res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) + assert res is None + assert len(caplog.record_tuples) == 2 + for tup in caplog.record_tuples: + assert str(store) in tup[0] + assert f"Calling {type(store).__name__}.set" in caplog.record_tuples[0][2] + assert f"Finished {type(store).__name__}.set" in caplog.record_tuples[1][2] + + caplog.clear() + keys = [k async for k in wrapped.list()] + assert keys == ["foo/bar/c/0"] + assert len(caplog.record_tuples) == 2 + for tup in caplog.record_tuples: + assert str(store) in tup[0] + assert f"Calling {type(store).__name__}.list" in caplog.record_tuples[0][2] + assert f"Finished {type(store).__name__}.list" in caplog.record_tuples[1][2] + + +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +async def test_logging_store_counter(store: Store) -> None: + wrapped = LoggingStore(store=store, log_level="DEBUG") + + arr = zarr.create(shape=(10,), store=wrapped, overwrite=True) + arr[:] = 1 + + assert wrapped.counter["set"] == 2 + assert wrapped.counter["get"] == 0 # 1 if overwrite=False + assert wrapped.counter["list"] == 0 + assert wrapped.counter["list_dir"] == 0 + assert wrapped.counter["list_prefix"] == 0