Skip to content

Commit

Permalink
Use threadsafe cache with pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Apr 17, 2024
1 parent 7dbc68a commit 4b508a1
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 5 deletions.
34 changes: 34 additions & 0 deletions src/saturn_engine/utils/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import typing as t

import functools
import threading
from collections import defaultdict


class _CacheLock:
__slots__ = ("lock",)

def __init__(self) -> None:
self.lock = threading.Lock()


def threadsafe_cache(func: t.Callable) -> t.Callable:
"""
Same as functools, but with a lock to ensure function are called only
once per key.
"""

cache: dict = defaultdict(_CacheLock)

def wrapper(*args: t.Any, **kwargs: t.Any) -> None:
key = functools._make_key(args, kwargs, typed=False)
result = cache[key]
if isinstance(result, _CacheLock):
with result.lock:
result = cache[key]
if isinstance(result, _CacheLock):
result = func(*args, **kwargs)
cache[key] = result
return result

return wrapper
7 changes: 4 additions & 3 deletions src/saturn_engine/utils/inspect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing as t

import dataclasses
import functools
import importlib
import inspect
import sys
Expand All @@ -10,6 +9,8 @@

import typing_inspect

from .cache import threadsafe_cache

R = t.TypeVar("R")


Expand Down Expand Up @@ -143,7 +144,7 @@ def import_name(name: str) -> t.Any:
raise ModuleNotFoundError(name)


@functools.cache
@threadsafe_cache
def signature(func: t.Callable) -> inspect.Signature:
_signature = inspect.signature(func)
_signature = eval_annotations(func, _signature)
Expand Down Expand Up @@ -190,7 +191,7 @@ def call(self, *, kwargs: t.Optional[dict[str, t.Any]] = None) -> R:
return self._func(**args_dict, **kwargs)


@functools.cache
@threadsafe_cache
def dataclass_from_params(func: t.Callable[..., R]) -> t.Type[BaseParamsDataclass[R]]:
cls_name = func.__name__ + ".params"
fields: list[tuple] = []
Expand Down
5 changes: 3 additions & 2 deletions src/saturn_engine/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import dataclasses
import json
from abc import abstractmethod
from functools import cache

import pydantic.v1
import pydantic.v1.json

from .cache import threadsafe_cache

OptionsSchemaT = t.TypeVar("OptionsSchemaT", bound="OptionsSchema")
T = t.TypeVar("T")

Expand All @@ -32,7 +33,7 @@ class ModelConfig(pydantic.v1.BaseConfig):
arbitrary_types_allowed = True


@cache
@threadsafe_cache
def schema_for(klass: t.Type) -> t.Type[pydantic.v1.BaseModel]:
if issubclass(klass, pydantic.v1.BaseModel):
return klass
Expand Down
56 changes: 56 additions & 0 deletions tests/utils/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import threading
from unittest.mock import Mock
from unittest.mock import call

from saturn_engine.utils.cache import threadsafe_cache


def test_threadsafe_lock() -> None:
in_func = threading.Event()
out_func = threading.Event()
spy = Mock()
results = []

@threadsafe_cache
def func(x: int) -> int:
spy(x)
in_func.set()

# only block for x==1, so we can test other x value
if x == 1:
out_func.wait()
return x

def check_func(x: int) -> None:
results.append(func(x))

t1 = threading.Thread(target=check_func, args=(1,))
t1.daemon = True
t1.start()

# Wait until t1 is in the cached function.
in_func.wait()
assert t1.is_alive()

in_func.clear()
t2 = threading.Thread(target=check_func, args=(1,))
t2.daemon = True
t2.start()
# Give some time for t2 to reach the cache lock.
assert not in_func.wait(0.1)

# Using a different key won't block on the same lock.
assert func(2) == 2

# Unblock t1
out_func.set()
t1.join()
t2.join()

# Both threads returned the x value.
assert results == [1, 1]
# It now use the cached value (in_func event is still blocking)
assert func(1) == 1

# Each key should only have been called once.
assert spy.call_args_list == [call(1), call(2)]

0 comments on commit 4b508a1

Please sign in to comment.