Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(store): add LoggingStore wrapper #2231

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions src/zarr/store/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

import inspect
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING

from zarr.abc.store import AccessMode, Store

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator

from zarr.core.buffer import Buffer, BufferPrototype


class LoggingStore(Store):
    """Store wrapper that logs and counts every call made to the wrapped store.

    Each delegated property access or method call emits a "Calling ..." and a
    "Finished ... in N seconds" record at INFO level on a logger named
    ``LoggingStore(<str(store)>)``, and increments ``counter[<method name>]``.

    Parameters
    ----------
    store
        The store to wrap; every operation is forwarded to it.
    log_level
        Level applied to the logger (and to the default handler, if created).
    log_handler
        Handler to attach to the logger. When omitted, a stream handler is
        attached only if no handler is reachable from the logger.
    """

    _store: Store
    counter: defaultdict[str, int]

    def __init__(
        self,
        store: Store,
        log_level: str = "DEBUG",
        log_handler: logging.Handler | None = None,
    ) -> None:
        self._store = store
        self.counter = defaultdict(int)

        self._configure_logger(log_level, log_handler)

    def _configure_logger(
        self, log_level: str = "DEBUG", log_handler: logging.Handler | None = None
    ) -> None:
        """Create the per-store logger and make sure it has somewhere to write."""
        self.log_level = log_level
        self.logger = logging.getLogger(f"LoggingStore({self._store!s})")
        self.logger.setLevel(log_level)

        if log_handler is not None:
            # An explicitly requested handler is always honoured.
            # `hasHandlers()` also looks at ancestor loggers, so gating on it
            # would silently drop the caller's handler whenever the root
            # logger is configured. The membership test avoids attaching the
            # same handler object twice when a store is wrapped repeatedly.
            if log_handler not in self.logger.handlers:
                self.logger.addHandler(log_handler)
        elif not self.logger.hasHandlers():
            # Nothing would receive the records; fall back to a stream handler.
            self.logger.addHandler(self._default_handler())

    def _default_handler(self) -> logging.Handler:
        """Define a default log handler"""
        handler = logging.StreamHandler()
        handler.setLevel(self.log_level)
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        return handler

    @contextmanager
    def log(self) -> Generator[None, None, None]:
        """Log entry/exit of the calling method and count the call.

        Must be entered directly from the method being logged (``with
        self.log():``), because the method name is read from the call stack.
        """
        # Frame 0 is this generator, frame 1 is contextlib's ``__enter__``,
        # frame 2 is the method that opened the ``with`` block.
        # ``context=0`` skips loading source lines for every frame, which the
        # default would do on each store operation.
        method = inspect.stack(context=0)[2].function
        op = f"{type(self._store).__name__}.{method}"
        self.logger.info("Calling %s", op)
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        try:
            self.counter[method] += 1
            yield
        finally:
            elapsed = time.perf_counter() - start_time
            self.logger.info("Finished %s in %.2f seconds", op, elapsed)

    # --- capability flags and state: delegated to the wrapped store ---------

    @property
    def supports_writes(self) -> bool:
        with self.log():
            return self._store.supports_writes

    @property
    def supports_deletes(self) -> bool:
        with self.log():
            return self._store.supports_deletes

    @property
    def supports_partial_writes(self) -> bool:
        with self.log():
            return self._store.supports_partial_writes

    @property
    def supports_listing(self) -> bool:
        with self.log():
            return self._store.supports_listing

    @property
    def _mode(self) -> AccessMode:  # type: ignore[override]
        with self.log():
            return self._store._mode

    @property
    def _is_open(self) -> bool:  # type: ignore[override]
        with self.log():
            return self._store._is_open

    async def empty(self) -> bool:
        with self.log():
            return await self._store.empty()

    async def clear(self) -> None:
        with self.log():
            return await self._store.clear()

    def __str__(self) -> str:
        return f"logging-{self._store!s}"

    def __repr__(self) -> str:
        # Embed the wrapped store's repr once; applying repr twice produced a
        # quoted, escaped string instead of the store's own representation.
        return f"LoggingStore({self._store!r})"

    def __eq__(self, other: object) -> bool:
        """Compare wrapped stores; another ``LoggingStore`` is unwrapped first.

        Without unwrapping, ``wrapper == wrapper`` would ask the inner store
        whether it equals the wrapper. Comparing against a bare store still
        delegates to the inner store unchanged.
        """
        with self.log():
            if isinstance(other, LoggingStore):
                other = other._store
            return self._store == other

    # --- read / write operations: delegated to the wrapped store ------------

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: tuple[int | None, int | None] | None = None,
    ) -> Buffer | None:
        with self.log():
            return await self._store.get(key=key, prototype=prototype, byte_range=byte_range)

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
    ) -> list[Buffer | None]:
        with self.log():
            return await self._store.get_partial_values(prototype=prototype, key_ranges=key_ranges)

    async def exists(self, key: str) -> bool:
        with self.log():
            return await self._store.exists(key)

    async def set(self, key: str, value: Buffer) -> None:
        with self.log():
            return await self._store.set(key=key, value=value)

    async def delete(self, key: str) -> None:
        with self.log():
            return await self._store.delete(key=key)

    async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
        with self.log():
            return await self._store.set_partial_values(key_start_values=key_start_values)

    # --- listing: async generators; the log block spans the whole iteration -

    async def list(self) -> AsyncGenerator[str, None]:
        with self.log():
            async for key in self._store.list():
                yield key

    async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
        with self.log():
            async for key in self._store.list_prefix(prefix=prefix):
                yield key

    async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
        with self.log():
            async for key in self._store.list_dir(prefix=prefix):
                yield key
50 changes: 50 additions & 0 deletions tests/v3/test_store/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import zarr
from zarr.core.buffer import default_buffer_prototype
from zarr.store.logging import LoggingStore

if TYPE_CHECKING:
from zarr.abc.store import Store


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
async def test_logging_store(store: Store, caplog) -> None:
    """Each wrapped call emits exactly one "Calling" and one "Finished" record."""
    store_cls = type(store).__name__
    logged = LoggingStore(store=store, log_level="DEBUG")
    buffer_cls = default_buffer_prototype().buffer

    def assert_call_logged(method: str) -> None:
        # Two records per call, both on the logger named after the store.
        records = caplog.record_tuples
        assert len(records) == 2
        for logger_name, _level, _message in records:
            assert str(store) in logger_name
        assert f"Calling {store_cls}.{method}" in records[0][2]
        assert f"Finished {store_cls}.{method}" in records[1][2]

    # Writing a key goes through `set`.
    caplog.clear()
    outcome = await logged.set("foo/bar/c/0", buffer_cls.from_bytes(b"\x01\x02\x03\x04"))
    assert outcome is None
    assert_call_logged("set")

    # Listing goes through the async generator `list`.
    caplog.clear()
    listed = [key async for key in logged.list()]
    assert listed == ["foo/bar/c/0"]
    assert_call_logged("list")


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
async def test_logging_store_counter(store: Store) -> None:
    """Creating and filling a small array records the expected per-method call counts."""
    logged = LoggingStore(store=store, log_level="DEBUG")

    array = zarr.create(shape=(10,), store=logged, overwrite=True)
    array[:] = 1

    expected_calls = {
        "set": 2,
        "get": 0,  # 1 if overwrite=False
        "list": 0,
        "list_dir": 0,
        "list_prefix": 0,
    }
    for method, count in expected_calls.items():
        assert logged.counter[method] == count