Skip to content

Commit

Permalink
Cache JustWatch filter results
Browse files Browse the repository at this point in the history
With a hot requests cache, this reduced the time to print one of my reports
from about 6 seconds to about 5 seconds.
  • Loading branch information
dseomn committed Sep 24, 2023
1 parent 2e30458 commit d273a20
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 4 deletions.
7 changes: 5 additions & 2 deletions rock_paper_sand/justwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _content_number(content: Any) -> multi_level_set.MultiLevelNumber:
return multi_level_set.MultiLevelNumber(tuple(parts))


class Filter(media_filter.Filter):
class Filter(media_filter.CachedFilter):
"""Filter based on JustWatch's API."""

def __init__(
Expand All @@ -206,6 +206,7 @@ def __init__(
*,
api: Api,
) -> None:
super().__init__()
self._config = filter_config
self._api = api

Expand Down Expand Up @@ -292,7 +293,9 @@ def _all_done(
return False
return True

def filter(self, item: media_item.MediaItem) -> media_filter.FilterResult:
def filter_implementation(
self, item: media_item.MediaItem
) -> media_filter.FilterResult:
"""See base class."""
now = datetime.datetime.now(tz=datetime.timezone.utc)
if not item.proto.justwatch_id:
Expand Down
21 changes: 21 additions & 0 deletions rock_paper_sand/media_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ def filter(self, item: media_item.MediaItem) -> FilterResult:
raise NotImplementedError()


class CachedFilter(Filter, abc.ABC):
    """Filter base class that remembers the result computed for each item.

    Results are kept in an in-memory dict keyed by the item's id, so calling
    filter() again with an item whose id was already seen returns the stored
    result without calling filter_implementation() again.

    Subclasses implement filter_implementation() and leave filter() alone.
    """

    def __init__(self) -> None:
        # Maps an item's id to the result filter_implementation() returned.
        self._result_by_id: dict[str, FilterResult] = {}

    @abc.abstractmethod
    def filter_implementation(self, item: media_item.MediaItem) -> FilterResult:
        """See Filter.filter."""
        raise NotImplementedError()

    def filter(self, item: media_item.MediaItem) -> FilterResult:
        """See base class."""
        cache = self._result_by_id
        item_id = item.id
        if item_id in cache:
            return cache[item_id]
        # First time this id is seen: compute once and remember the result.
        result = self.filter_implementation(item)
        cache[item_id] = result
        return result


class Not(Filter):
"""Inverts another filter."""

Expand Down
27 changes: 25 additions & 2 deletions rock_paper_sand/media_filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,23 @@
from rock_paper_sand.proto import config_pb2


class _ExtraInfoFilter(media_filter.Filter):
class _ExtraInfoFilter(media_filter.CachedFilter):
"""Test filter that returns the given extra info.
Attributes:
call_count: Number of times filter_implementation() was called.
"""

def __init__(self, extra: Set[str]) -> None:
super().__init__()
self._extra = extra
self.call_count = 0

def filter(self, item: media_item.MediaItem) -> media_filter.FilterResult:
def filter_implementation(
self, item: media_item.MediaItem
) -> media_filter.FilterResult:
"""See base class."""
self.call_count += 1
return media_filter.FilterResult(True, extra=self._extra)


Expand Down Expand Up @@ -231,6 +242,18 @@ def test_basic_filter(
result = test_filter.filter(media_item.MediaItem.from_config(item))
self.assertEqual(expected_result, result)

def test_cached_filter(self) -> None:
    """Filtering the same item twice runs the implementation only once."""
    cached_filter = _ExtraInfoFilter({"foo"})
    media = media_item.MediaItem.from_config(config_pb2.MediaItem())
    expected = media_filter.FilterResult(True, extra={"foo"})

    # Both calls happen before any assertion, as the call count is checked
    # only after the two results have been compared.
    results = [cached_filter.filter(media), cached_filter.filter(media)]

    for result in results:
        self.assertEqual(expected, result)
    self.assertEqual(1, cached_filter.call_count)

def test_justwatch_filter(self) -> None:
mock_filter = mock.create_autospec(
media_filter.Filter, spec_set=True, instance=True
Expand Down
8 changes: 8 additions & 0 deletions rock_paper_sand/media_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Sequence
import dataclasses
from typing import Self
import uuid

from rock_paper_sand import multi_level_set
from rock_paper_sand.proto import config_pb2
Expand All @@ -26,11 +27,18 @@ class MediaItem:
"""Media item.
Attributes:
id: Unique ID of the media item. This is not stable across runs of the
program, so it should not be stored anywhere or shown to the user.
It's designed for caching filter results in memory.
proto: Proto from the config file.
done: Parsed proto.done field.
parts: Parsed proto.parts field.
"""

id: str = dataclasses.field(
default_factory=lambda: str(uuid.uuid4()),
repr=False,
)
proto: config_pb2.MediaItem
done: multi_level_set.MultiLevelSet
parts: Sequence["MediaItem"]
Expand Down
8 changes: 8 additions & 0 deletions rock_paper_sand/media_item_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def test_from_config(self) -> None:

self.assertEqual(
media_item.MediaItem(
id=mock.ANY,
proto=proto,
done=mock.ANY,
parts=(
media_item.MediaItem(
id=mock.ANY,
proto=config_pb2.MediaItem(name="some-part"),
done=mock.ANY,
parts=(),
Expand All @@ -57,6 +59,12 @@ def test_from_config(self) -> None:
self.assertIn(multi_level_set.parse_number("1"), item.done)
self.assertNotIn(multi_level_set.parse_number("1"), item.parts[0].done)

def test_id(self) -> None:
    """Two items built from identical protos do not compare equal."""
    first = media_item.MediaItem.from_config(config_pb2.MediaItem())
    second = media_item.MediaItem.from_config(config_pb2.MediaItem())

    self.assertNotEqual(first, second)


if __name__ == "__main__":
absltest.main()

0 comments on commit d273a20

Please sign in to comment.