Skip to content

Commit

Permalink
New tests and feature for override providers by names in kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan committed Aug 3, 2024
1 parent d5b757e commit c156143
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 18 deletions.
18 changes: 17 additions & 1 deletion src/injection/base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def _get_providers_generator(cls) -> Iterator[BaseProvider]:
def get_providers(cls) -> List[BaseProvider]:
return list(cls.__get_providers().values())

@classmethod
@contextmanager
def override_providers_kwargs(
cls,
*,
reset_singletons: bool = False,
**providers_for_overriding,
) -> Iterator[None]:
with cls.override_providers(
providers_for_overriding,
reset_singletons=reset_singletons,
):
yield

@classmethod
@contextmanager
def override_providers(
Expand All @@ -55,6 +69,7 @@ def override_providers(
msg = f"Provider with name {given_name!r} not found"
raise RuntimeError(msg)

# Reset singletons that which were resolved BEFORE the current context
if reset_singletons:
cls.reset_singletons()

Expand All @@ -68,6 +83,7 @@ def override_providers(
provider = current_providers[provider_name]
provider.reset_override()

# Reset singletons that which were resolved INSIDE the current context
if reset_singletons:
cls.reset_singletons()

Expand All @@ -77,7 +93,7 @@ def reset_singletons(cls) -> None:

for provider in providers_gen:
if isinstance(provider, Singleton):
provider.reset_cache()
provider.reset()

@classmethod
def reset_override(cls) -> None:
Expand Down
4 changes: 1 addition & 3 deletions src/injection/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ def _resolve(self, *args, **kwargs) -> T:
Positional arguments are appended after Factory positional dependencies.
Keyword arguments have the priority over the Factory keyword dependencies with the same name.
"""
if args or kwargs:
self.reset_cache()

if self._instance is None:
self._instance = super()._resolve(*args, **kwargs)

return self._instance

def reset_cache(self):
def reset(self):
self._instance = None
5 changes: 1 addition & 4 deletions tests/test_base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ def test_override_providers_success(container):
mock_redis.get.return_value = -999
nested_override_objects = {"redis": mock_redis, "num": 92934}

with container.override_providers(
nested_override_objects,
reset_singletons=False,
):
with container.override_providers(nested_override_objects):
assert container.num() == 92934
assert container.redis().get() == -999

Expand Down
17 changes: 11 additions & 6 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from unittest.mock import Mock

from tests.container_objects import func_with_injections
from tests.container_objects import Settings, func_with_injections


def test_e2e_success(container):
Expand Down Expand Up @@ -45,17 +45,13 @@ class MockSettings:
mock_settings = MockSettings(redis_url=mock_url)
providers_for_overriding = {"settings": mock_settings}

# container.reset_singletons()
# res = container.callable_test_1(d='sfs')
with container.override_providers(providers_for_overriding, reset_singletons=True):
redis_url = func_with_injections(2, ddd="sfs")
assert mock_url == redis_url

# container.reset_singletons()
redis_url = func_with_injections(2, ddd="sfs")
assert redis_url == "redis://localhost"

# mock redis
mock_redis = Mock()
mock_redis.url = "mock_url_tests"
mock_redis.get.return_value = None
Expand All @@ -64,5 +60,14 @@ class MockSettings:
_ = container.redis()
_ = container.settings()

with container.override_providers(providers_for_overriding):
with container.override_providers(providers_for_overriding, reset_singletons=True):
assert mock_redis.url == func_with_injections(2, ddd="sfs")

mock_settings = Settings(redis_url="mock_redis_url_2")

with container.override_providers_kwargs(
settings=mock_settings,
reset_singletons=True,
):
assert container.redis().url == "mock_redis_url_2"
assert func_with_injections(2, ddd="sfs") == "mock_redis_url_2"
26 changes: 22 additions & 4 deletions tests/test_providers/test_singleton.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass

import pytest
from injection import providers


Expand All @@ -16,7 +17,7 @@ def test_singleton_override_args():
assert resolved.field1 == "new_value"
assert resolved.field2 == 100

singleton_provider.reset_cache()
singleton_provider.reset()
resolved = singleton_provider()
assert resolved.field1 == "value"
assert resolved.field2 == 1
Expand All @@ -29,6 +30,23 @@ def test_singleton_resolving_with_override_params_no_work_without_reset_cache():
assert resolved.field1 == "new_value"
assert resolved.field2 == 100

resolved = singleton_provider()
assert resolved.field1 == "new_value"
assert resolved.field2 == 100
resolved = singleton_provider(field1="override_value", field2=239)

with pytest.raises(AssertionError):
assert resolved.field1 == "override_value"

with pytest.raises(AssertionError):
assert resolved.field2 == 239


def test_singleton_reset_smoke():
provider = providers.Singleton(SomeClass, field1="...", field2=-9000)
obj = provider()
obj2 = provider()

assert obj is obj2
provider.reset()

obj3 = provider()
assert obj is not obj3
assert obj2 is not obj3

0 comments on commit c156143

Please sign in to comment.