diff --git a/README.md b/README.md index a675a8c..14a40d0 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,10 @@ You can find the source code for `py-dependency-injection` on [GitHub](https://g ## Release Notes +### [1.0.0-alpha.10](https://github.com/runemalm/py-dependency-injection/releases/tag/v1.0.0-alpha.10) (2024-08-11) + +- **Tagged Constructor Injection**: Introduced support for constructor injection using the `Tagged`, `AnyTagged`, and `AllTagged` classes. This allows for seamless injection of dependencies that have been registered with specific tags, enhancing flexibility and control in managing your application's dependencies. + ### [1.0.0-alpha.9](https://github.com/runemalm/py-dependency-injection/releases/tag/v1.0.0-alpha.9) (2024-08-08) - **Breaking Change**: Removed constructor injection when resolving dataclasses. diff --git a/docs/conf.py b/docs/conf.py index 9f3163c..7329242 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,7 +35,7 @@ version = "1.0" # The full version, including alpha/beta/rc tags -release = "1.0.0-alpha.9" +release = "1.0.0-alpha.10" # -- General configuration --------------------------------------------------- diff --git a/docs/examples.rst b/docs/examples.rst index 1346e04..908e923 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -225,6 +225,56 @@ This example illustrates how to use constructor injection to automatically injec print(repository.connection.__class__.__name__) # Output: PostgresConnection +###################################################### +Using constructor injection with tagged dependencies +###################################################### + +This example demonstrates how to use constructor injection to automatically inject tagged dependencies into your classes. By leveraging tags, you can group and categorize dependencies, enabling automatic injection based on specific criteria. + +.. code-block:: python + + class PrimaryPort: + pass + + class SecondaryPort: + pass + + class HttpAdapter(PrimaryPort): + pass + + class PostgresCarRepository(SecondaryPort): + pass + + class Application: + def __init__(self, primary_ports: List[Tagged[PrimaryPort]], secondary_ports: List[Tagged[SecondaryPort]]): + self.primary_ports = primary_ports + self.secondary_ports = secondary_ports + + # Register dependencies with tags + dependency_container.register_transient(HttpAdapter, tags={PrimaryPort}) + dependency_container.register_transient(PostgresCarRepository, tags={SecondaryPort}) + + # Register the Application class to have its dependencies injected + dependency_container.register_transient(Application) + + # Resolve the Application class, with tagged dependencies automatically injected + application = dependency_container.resolve(Application) + + # Use the injected dependencies + print(f"Primary ports: {len(application.primary_ports)}") # Output: Primary ports: 1 + print(f"Secondary ports: {len(application.secondary_ports)}") # Output: Secondary ports: 1 + print(f"Primary port instance: {type(application.primary_ports[0]).__name__}") # Output: HttpAdapter + print(f"Secondary port instance: {type(application.secondary_ports[0]).__name__}") # Output: PostgresCarRepository + + +In this example, the `Application` class expects lists of instances tagged with `PrimaryPort` and `SecondaryPort`. By tagging and registering these dependencies, the container automatically injects the correct instances into the `Application` class when it is resolved. + +Tags offer a powerful way to manage dependencies, ensuring that the right instances are injected based on your application's needs. + +.. note:: + You can also use the ``AnyTagged`` and ``AllTagged`` classes to inject dependencies based on more complex tagging logic. ``AnyTagged`` allows injection of any dependency matching one or more specified tags, while ``AllTagged`` requires the dependency to match all specified tags before injection. This provides additional flexibility in managing and resolving dependencies in your application. + + ###################### Using method injection ###################### diff --git a/docs/releases.rst b/docs/releases.rst index 1e22003..03c715b 100644 --- a/docs/releases.rst +++ b/docs/releases.rst @@ -6,6 +6,12 @@ Version History ############### +**1.0.0-alpha.10 (2024-08-11)** + +- **Tagged Constructor Injection**: Introduced support for constructor injection using the `Tagged`, `AnyTagged`, and `AllTagged` classes. This allows for seamless injection of dependencies that have been registered with specific tags, enhancing flexibility and control in managing your application's dependencies. + +`View release on GitHub `_ + **1.0.0-alpha.9 (2024-08-08)** - **Breaking Change**: Removed constructor injection when resolving dataclasses. diff --git a/setup.py b/setup.py index 2796357..809a8aa 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="py-dependency-injection", - version="1.0.0-alpha.9", + version="1.0.0-alpha.10", author="David Runemalm, 2024", author_email="david.runemalm@gmail.com", description="A dependency injection library for Python.", diff --git a/src/dependency_injection/container.py b/src/dependency_injection/container.py index ee7defa..ab035f4 100644 --- a/src/dependency_injection/container.py +++ b/src/dependency_injection/container.py @@ -3,6 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Type +from dependency_injection.tags.all_tagged import AllTagged +from dependency_injection.tags.any_tagged import AnyTagged +from dependency_injection.tags.tagged import Tagged from dependency_injection.registration import Registration from dependency_injection.scope import DEFAULT_SCOPE_NAME, Scope from dependency_injection.utils.singleton_meta import SingletonMeta @@ -23,8 +26,7 @@ def __init__(self, name: str = None): @classmethod def get_instance(cls, name: str = None) -> Self: - if name is None: - name = DEFAULT_CONTAINER_NAME + name = name or DEFAULT_CONTAINER_NAME if (cls, name) not in cls._instances: cls._instances[(cls, name)] = cls(name) @@ -48,11 +50,7 @@ def register_transient( tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None, ) -> None: - if implementation is None: - implementation = dependency - if dependency in self._registrations: - raise ValueError(f"Dependency {dependency} is already registered.") - self._registrations[dependency] = Registration( + self._register( dependency, implementation, Scope.TRANSIENT, tags, constructor_args ) @@ -63,13 +61,7 @@ def register_scoped( tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None, ) -> None: - if implementation is None: - implementation = dependency - if dependency in self._registrations: - raise ValueError(f"Dependency {dependency} is already registered.") - self._registrations[dependency] = Registration( - dependency, implementation, Scope.SCOPED, tags, constructor_args - ) + self._register(dependency, implementation, Scope.SCOPED, tags, constructor_args) def register_singleton( self, @@ -78,11 +70,7 @@ def register_singleton( tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None, ) -> None: - if implementation is None: - implementation = dependency - if dependency in self._registrations: - raise ValueError(f"Dependency {dependency} is already registered.") - self._registrations[dependency] = Registration( + self._register( dependency, implementation, Scope.SINGLETON, tags, constructor_args ) @@ -93,73 +81,104 @@ def register_factory( factory_args: Optional[Dict[str, Any]] = None, tags: Optional[set] = None, ) -> None: - if dependency in self._registrations: - raise ValueError(f"Dependency {dependency} is already registered.") + self._validate_registration(dependency) self._registrations[dependency] = Registration( - dependency, None, Scope.FACTORY, None, tags, factory, factory_args + dependency, None, Scope.FACTORY, tags, None, factory, factory_args ) def register_instance( self, dependency: Type, instance: Any, tags: Optional[set] = None ) -> None: - if dependency in self._registrations: - raise ValueError(f"Dependency {dependency} is already registered.") + self._validate_registration(dependency) self._registrations[dependency] = Registration( - dependency, type(instance), Scope.SINGLETON, constructor_args={}, tags=tags + dependency, type(instance), Scope.SINGLETON, tags=tags ) self._singleton_instances[dependency] = instance - def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Type: + def _register( + self, + dependency: Type, + implementation: Optional[Type], + scope: Scope, + tags: Optional[set], + constructor_args: Optional[Dict[str, Any]], + ) -> None: + implementation = implementation or dependency + self._validate_registration(dependency) + self._registrations[dependency] = Registration( + dependency, implementation, scope, tags, constructor_args + ) + + def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Any: self._has_resolved = True if scope_name not in self._scoped_instances: self._scoped_instances[scope_name] = {} - if dependency not in self._registrations: + registration = self._registrations.get(dependency) + if not registration: raise KeyError(f"Dependency {dependency.__name__} is not registered.") - registration = self._registrations[dependency] - scope = registration.scope - implementation = registration.implementation - constructor_args = registration.constructor_args + constructor_args = registration.constructor_args or {} + self._validate_constructor_args(constructor_args, registration.implementation) - self._validate_constructor_args( - constructor_args=constructor_args, implementation=implementation - ) + return self._resolve_by_scope(registration, scope_name) + + def _resolve_by_scope(self, registration: Registration, scope_name: str) -> Any: + scope = registration.scope if scope == Scope.TRANSIENT: return self._inject_dependencies( - implementation=implementation, constructor_args=constructor_args + registration.implementation, + constructor_args=registration.constructor_args, ) elif scope == Scope.SCOPED: - if dependency not in self._scoped_instances[scope_name]: - self._scoped_instances[scope_name][ - dependency - ] = self._inject_dependencies( - implementation=implementation, - scope_name=scope_name, - constructor_args=constructor_args, + instances = self._scoped_instances[scope_name] + if registration.dependency not in instances: + instances[registration.dependency] = self._inject_dependencies( + registration.implementation, + scope_name, + registration.constructor_args, ) - return self._scoped_instances[scope_name][dependency] + return instances[registration.dependency] elif scope == Scope.SINGLETON: - if dependency not in self._singleton_instances: - self._singleton_instances[dependency] = self._inject_dependencies( - implementation=implementation, constructor_args=constructor_args + if registration.dependency not in self._singleton_instances: + self._singleton_instances[ + registration.dependency + ] = self._inject_dependencies( + registration.implementation, + constructor_args=registration.constructor_args, ) - return self._singleton_instances[dependency] + return self._singleton_instances[registration.dependency] elif scope == Scope.FACTORY: - factory = registration.factory - factory_args = registration.factory_args or {} - return factory(**factory_args) + return registration.factory(**(registration.factory_args or {})) raise ValueError(f"Invalid dependency scope: {scope}") - def resolve_all(self, tags: Optional[set] = None) -> List[Any]: - tags = tags or [] + def resolve_all( + self, tags: Optional[set] = None, match_all_tags: bool = False + ) -> List[Any]: + tags = tags or set() resolved_dependencies = [] + for registration in self._registrations.values(): - if not len(tags) or tags.intersection(registration.tags): + if not tags: + # If no tags are provided, resolve all dependencies resolved_dependencies.append(self.resolve(registration.dependency)) + else: + if match_all_tags: + # Match dependencies that have all the specified tags + if registration.tags and tags.issubset(registration.tags): + resolved_dependencies.append( + self.resolve(registration.dependency) + ) + else: + # Match dependencies that have any of the specified tags + if registration.tags and tags.intersection(registration.tags): + resolved_dependencies.append( + self.resolve(registration.dependency) + ) + return resolved_dependencies def _validate_constructor_args( @@ -184,6 +203,10 @@ def _validate_constructor_args( f"provided type: {type(arg_value)}." ) + def _validate_registration(self, dependency: Type) -> None: + if dependency in self._registrations: + raise ValueError(f"Dependency {dependency} is already registered.") + def _inject_dependencies( self, implementation: Type, @@ -199,20 +222,56 @@ def _inject_dependencies( dependencies = {} for param_name, param_info in params.items(): if param_name != "self": - # Check for *args and **kwargs if param_info.kind == inspect.Parameter.VAR_POSITIONAL: - # *args parameter pass elif param_info.kind == inspect.Parameter.VAR_KEYWORD: - # **kwargs parameter pass else: - # Check if constructor_args has an argument with the same name if constructor_args and param_name in constructor_args: dependencies[param_name] = constructor_args[param_name] else: - dependencies[param_name] = self.resolve( - param_info.annotation, scope_name=scope_name - ) + if ( + hasattr(param_info.annotation, "__origin__") + and param_info.annotation.__origin__ is list + ): + inner_type = param_info.annotation.__args__[0] + + tagged_dependencies = [] + if isinstance(inner_type, type) and issubclass( + inner_type, Tagged + ): + tagged_type = inner_type.tag + tagged_dependencies = self.resolve_all( + tags={tagged_type} + ) + + elif isinstance(inner_type, type) and issubclass( + inner_type, AnyTagged + ): + tagged_dependencies = self.resolve_all( + tags=inner_type.tags, match_all_tags=False + ) + + elif isinstance(inner_type, type) and issubclass( + inner_type, AllTagged + ): + tagged_dependencies = self.resolve_all( + tags=inner_type.tags, match_all_tags=True + ) + + dependencies[param_name] = tagged_dependencies + + else: + try: + dependencies[param_name] = self.resolve( + param_info.annotation, scope_name=scope_name + ) + except KeyError: + raise ValueError( + f"Cannot resolve dependency for parameter " + f"'{param_name}' of type " + f"'{param_info.annotation}' in class " + f"'{implementation.__name__}'." + ) return implementation(**dependencies) diff --git a/src/dependency_injection/tags/__init__.py b/src/dependency_injection/tags/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dependency_injection/tags/all_tagged.py b/src/dependency_injection/tags/all_tagged.py new file mode 100644 index 0000000..18964a6 --- /dev/null +++ b/src/dependency_injection/tags/all_tagged.py @@ -0,0 +1,20 @@ +from typing import Generic, Set, Tuple, Type, TypeVar, Union + +T = TypeVar("T") + + +class AllTagged(Generic[T]): + def __init__(self, tags: Tuple[Type[T], ...]): + self.tags: Set[Type[T]] = set(tags) + + @classmethod + def __class_getitem__( + cls, item: Union[Type[T], Tuple[Type[T], ...]] + ) -> Type["AllTagged"]: + if not isinstance(item, tuple): + item = (item,) + return type( + f'AllTagged_{"_".join([t.__name__ for t in item])}', + (cls,), + {"tags": set(item)}, + ) diff --git a/src/dependency_injection/tags/any_tagged.py b/src/dependency_injection/tags/any_tagged.py new file mode 100644 index 0000000..b8be003 --- /dev/null +++ b/src/dependency_injection/tags/any_tagged.py @@ -0,0 +1,22 @@ +from typing import Type, Generic, TypeVar, Tuple, Union, Set + +T = TypeVar("T") + + +class AnyTagged(Generic[T]): + def __init__(self, tags: Union[Tuple[Type[T], ...], Type[T]]): + if not isinstance(tags, tuple): + tags = (tags,) + self.tags: Set[Type[T]] = set(tags) + + @classmethod + def __class_getitem__( + cls, item: Union[Type[T], Tuple[Type[T], ...]] + ) -> Type["AnyTagged"]: + if not isinstance(item, tuple): + item = (item,) + return type( + f'AnyTagged_{"_".join([t.__name__ for t in item])}', + (cls,), + {"tags": set(item)}, + ) diff --git a/src/dependency_injection/tags/tagged.py b/src/dependency_injection/tags/tagged.py new file mode 100644 index 0000000..5064efb --- /dev/null +++ b/src/dependency_injection/tags/tagged.py @@ -0,0 +1,12 @@ +from typing import Type, Generic, TypeVar + +T = TypeVar("T") + + +class Tagged(Generic[T]): + def __init__(self, tag: Type[T]): + self.tag = tag + + @classmethod + def __class_getitem__(cls, item: Type[T]) -> Type["Tagged"]: + return type(f"Tagged_{item.__name__}", (cls,), {"tag": item}) diff --git a/tests/unit_test/container/resolve/test_resolve_all.py b/tests/unit_test/container/resolve/test_resolve_all.py index 414d5f3..1b30c48 100644 --- a/tests/unit_test/container/resolve/test_resolve_all.py +++ b/tests/unit_test/container/resolve/test_resolve_all.py @@ -65,6 +65,44 @@ class Innovation: ) ) + def test_returns_only_dependencies_matching_all_tags_when_match_all_tags_is_true( + self, + ): + # arrange + class Driveable: + pass + + class Transporting: + pass + + class Vehicle: + pass + + class Car(Vehicle): + pass + + class Innovation: + pass + + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient(Vehicle, tags={Driveable, Transporting}) + dependency_container.register_transient(Car, tags={Driveable, Transporting}) + dependency_container.register_transient(Innovation, tags={Driveable}) + + # act + resolved_dependencies = dependency_container.resolve_all( + tags={Driveable, Transporting}, match_all_tags=True + ) + + # assert + self.assertEqual(len(resolved_dependencies), 2) + self.assertTrue( + any(isinstance(dependency, Vehicle) for dependency in resolved_dependencies) + ) + self.assertTrue( + any(isinstance(dependency, Car) for dependency in resolved_dependencies) + ) + def test_does_not_return_dependency_without_tag( self, ): diff --git a/tests/unit_test/container/resolve/test_resolve_with_injection.py b/tests/unit_test/container/resolve/test_resolve_with_injection.py index f8a10f3..ae23cfb 100644 --- a/tests/unit_test/container/resolve/test_resolve_with_injection.py +++ b/tests/unit_test/container/resolve/test_resolve_with_injection.py @@ -1,6 +1,10 @@ from dataclasses import dataclass +from typing import List from dependency_injection.container import DependencyContainer +from dependency_injection.tags.all_tagged import AllTagged +from dependency_injection.tags.any_tagged import AnyTagged +from dependency_injection.tags.tagged import Tagged from unit_test.unit_test_case import UnitTestCase @@ -49,3 +53,157 @@ class Car: self.assertIsNone( resolved_dependency.engine ) # Should be None since injection is skipped + + def test_resolve_injects_tagged_dependencies(self): + # arrange + class HttpAdapter: + pass + + class CliAdapter: + pass + + class PrimaryPort: + pass + + class PostgresCarRepository: + pass + + class SecondaryPort: + pass + + class Application: + def __init__(self, primary_ports: List[Tagged[PrimaryPort]]): + self.primary_ports = primary_ports + + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient(HttpAdapter, tags={PrimaryPort}) + dependency_container.register_transient(CliAdapter, tags={PrimaryPort}) + dependency_container.register_transient( + PostgresCarRepository, tags={SecondaryPort} + ) + dependency_container.register_transient(Application) + + # act + resolved_dependency = dependency_container.resolve(Application) + + # assert + self.assertIsInstance(resolved_dependency, Application) + self.assertEqual(len(resolved_dependency.primary_ports), 2) + self.assertTrue( + any(isinstance(t, HttpAdapter) for t in resolved_dependency.primary_ports) + ) + self.assertTrue( + any(isinstance(t, CliAdapter) for t in resolved_dependency.primary_ports) + ) + self.assertFalse( + any( + isinstance(t, PostgresCarRepository) + for t in resolved_dependency.primary_ports + ) + ) + + def test_resolve_injects_any_tagged_dependencies(self): + # arrange + class Volvo: + pass + + class Scania: + pass + + class Car: + pass + + class Truck: + pass + + class Fruit: + pass + + class Banana: + pass + + class Trip: + def __init__(self, transportations: List[AnyTagged[Car, Truck]]): + self.transportations = transportations + + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient(Volvo, tags={Car}) + dependency_container.register_transient(Scania, tags={Truck}) + dependency_container.register_transient(Banana, tags={Fruit}) + dependency_container.register_transient(Trip) + + # act + resolved_dependency = dependency_container.resolve(Trip) + + # assert + self.assertIsInstance(resolved_dependency, Trip) + self.assertIsInstance(resolved_dependency, Trip) + self.assertEqual(len(resolved_dependency.transportations), 2) + self.assertTrue( + any(isinstance(t, Volvo) for t in resolved_dependency.transportations) + ) + self.assertTrue( + any(isinstance(t, Scania) for t in resolved_dependency.transportations) + ) + self.assertFalse( + any(isinstance(t, Banana) for t in resolved_dependency.transportations) + ) + + def test_resolve_injects_all_tagged_dependencies(self): + # arrange + class Red: + pass + + class Green: + pass + + class Blue: + pass + + class White: + pass + + class NonWhite: + pass + + class Palette: + def __init__(self, white_colors: List[AllTagged[Red, Green, Blue]]): + self.white_colors = white_colors + + # Register instances with various tags + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient( + White, tags={Red, Green, Blue} + ) # Should be included + dependency_container.register_transient( + NonWhite, tags={Red, Green} + ) # Should NOT be included + dependency_container.register_transient(Palette) + + # act + resolved_dependency = dependency_container.resolve(Palette) + + # assert + self.assertIsInstance(resolved_dependency, Palette) + self.assertEqual(len(resolved_dependency.white_colors), 1) + self.assertIsInstance(resolved_dependency.white_colors[0], White) + self.assertNotIsInstance(resolved_dependency.white_colors[0], NonWhite) + + def test_resolve_injects_empty_list_if_no_tags_match(self): + # arrange + class PrimaryPort: + pass + + class Application: + def __init__(self, primary_ports: List[Tagged[PrimaryPort]]): + self.primary_ports = primary_ports + + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient(Application) + + # act + resolved_dependency = dependency_container.resolve(Application) + + # assert + self.assertIsInstance(resolved_dependency, Application) + self.assertEqual(len(resolved_dependency.primary_ports), 0)