diff --git a/rock_paper_sand/justwatch.py b/rock_paper_sand/justwatch.py index 04c1a5d..374315b 100644 --- a/rock_paper_sand/justwatch.py +++ b/rock_paper_sand/justwatch.py @@ -197,7 +197,7 @@ def _content_number(content: Any) -> multi_level_set.MultiLevelNumber: return multi_level_set.MultiLevelNumber(tuple(parts)) -class Filter(media_filter.Filter): +class Filter(media_filter.CachedFilter): """Filter based on JustWatch's API.""" def __init__( @@ -206,6 +206,7 @@ def __init__( *, api: Api, ) -> None: + super().__init__() self._config = filter_config self._api = api @@ -292,7 +293,9 @@ def _all_done( return False return True - def filter(self, item: media_item.MediaItem) -> media_filter.FilterResult: + def filter_implementation( + self, item: media_item.MediaItem + ) -> media_filter.FilterResult: """See base class.""" now = datetime.datetime.now(tz=datetime.timezone.utc) if not item.proto.justwatch_id: diff --git a/rock_paper_sand/media_filter.py b/rock_paper_sand/media_filter.py index 14e5190..4cc5879 100644 --- a/rock_paper_sand/media_filter.py +++ b/rock_paper_sand/media_filter.py @@ -48,6 +48,27 @@ def filter(self, item: media_item.MediaItem) -> FilterResult: raise NotImplementedError() +class CachedFilter(Filter, abc.ABC): + """Base class for filters that cache their results. + + Child classes should override filter_implementation() instead of filter(). + """ + + def __init__(self) -> None: + self._result_by_id: dict[str, FilterResult] = {} + + @abc.abstractmethod + def filter_implementation(self, item: media_item.MediaItem) -> FilterResult: + """See Filter.filter.""" + raise NotImplementedError() + + def filter(self, item: media_item.MediaItem) -> FilterResult: + """See base class.""" + if item.id not in self._result_by_id: + self._result_by_id[item.id] = self.filter_implementation(item) + return self._result_by_id[item.id] + + class Not(Filter): """Inverts another filter.""" diff --git a/rock_paper_sand/media_filter_test.py b/rock_paper_sand/media_filter_test.py index 680def8..7f88019 100644 --- a/rock_paper_sand/media_filter_test.py +++ b/rock_paper_sand/media_filter_test.py @@ -28,12 +28,23 @@ from rock_paper_sand.proto import config_pb2 -class _ExtraInfoFilter(media_filter.Filter): +class _ExtraInfoFilter(media_filter.CachedFilter): + """Test filter that returns the given extra info. + + Attributes: + call_count: Number of times filter_implementation() was called. + """ + def __init__(self, extra: Set[str]) -> None: + super().__init__() self._extra = extra + self.call_count = 0 - def filter(self, item: media_item.MediaItem) -> media_filter.FilterResult: + def filter_implementation( + self, item: media_item.MediaItem + ) -> media_filter.FilterResult: """See base class.""" + self.call_count += 1 return media_filter.FilterResult(True, extra=self._extra) @@ -231,6 +242,18 @@ def test_basic_filter( result = test_filter.filter(media_item.MediaItem.from_config(item)) self.assertEqual(expected_result, result) + def test_cached_filter(self) -> None: + test_filter = _ExtraInfoFilter({"foo"}) + item = media_item.MediaItem.from_config(config_pb2.MediaItem()) + expected_result = media_filter.FilterResult(True, extra={"foo"}) + + first_result = test_filter.filter(item) + second_result = test_filter.filter(item) + + self.assertEqual(expected_result, first_result) + self.assertEqual(expected_result, second_result) + self.assertEqual(1, test_filter.call_count) + def test_justwatch_filter(self) -> None: mock_filter = mock.create_autospec( media_filter.Filter, spec_set=True, instance=True diff --git a/rock_paper_sand/media_item.py b/rock_paper_sand/media_item.py index 74cde02..7a4a2c8 100644 --- a/rock_paper_sand/media_item.py +++ b/rock_paper_sand/media_item.py @@ -16,6 +16,7 @@ from collections.abc import Sequence import dataclasses from typing import Self +import uuid from rock_paper_sand import multi_level_set from rock_paper_sand.proto import config_pb2 @@ -26,11 +27,18 @@ class MediaItem: """Media item. Attributes: + id: Unique ID of the media item. This is not stable across runs of the + program, so it should not be stored anywhere or shown to the user. + It's designed for caching filter results in memory. proto: Proto from the config file. done: Parsed proto.done field. parts: Parsed proto.parts field. """ + id: str = dataclasses.field( + default_factory=lambda: str(uuid.uuid4()), + repr=False, + ) proto: config_pb2.MediaItem done: multi_level_set.MultiLevelSet parts: Sequence["MediaItem"] diff --git a/rock_paper_sand/media_item_test.py b/rock_paper_sand/media_item_test.py index 61bad39..e76b478 100644 --- a/rock_paper_sand/media_item_test.py +++ b/rock_paper_sand/media_item_test.py @@ -42,10 +42,12 @@ def test_from_config(self) -> None: self.assertEqual( media_item.MediaItem( + id=mock.ANY, proto=proto, done=mock.ANY, parts=( media_item.MediaItem( + id=mock.ANY, proto=config_pb2.MediaItem(name="some-part"), done=mock.ANY, parts=(), @@ -57,6 +59,12 @@ def test_from_config(self) -> None: self.assertIn(multi_level_set.parse_number("1"), item.done) self.assertNotIn(multi_level_set.parse_number("1"), item.parts[0].done) + def test_id(self) -> None: + self.assertNotEqual( + media_item.MediaItem.from_config(config_pb2.MediaItem()), + media_item.MediaItem.from_config(config_pb2.MediaItem()), + ) + if __name__ == "__main__": absltest.main()