@@ -50,6 +50,12 @@ class GenericMeta(type):
50
50
'wire' ,
51
51
'unwire' ,
52
52
'inject' ,
53
+ 'as_int' ,
54
+ 'as_float' ,
55
+ 'as_' ,
56
+ 'required' ,
57
+ 'invariant' ,
58
+ 'provided' ,
53
59
'Provide' ,
54
60
'Provider' ,
55
61
'Closing' ,
@@ -85,16 +91,23 @@ def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
85
91
86
92
class ProvidersMap :
87
93
94
+ CONTAINER_STRING_ID = '<container>'
95
+
88
96
def __init__ (self , container ):
89
97
self ._container = container
90
98
self ._map = self ._create_providers_map (
91
99
current_container = container ,
92
- original_container = container .declarative_parent ,
100
+ original_container = (
101
+ container .declarative_parent
102
+ if container .declarative_parent
103
+ else container
104
+ ),
93
105
)
94
106
95
107
def resolve_provider (
96
108
self ,
97
- provider : providers .Provider ,
109
+ provider : Union [providers .Provider , str ],
110
+ modifier : Optional ['Modifier' ] = None ,
98
111
) -> Optional [providers .Provider ]:
99
112
if isinstance (provider , providers .Delegate ):
100
113
return self ._resolve_delegate (provider )
@@ -109,14 +122,29 @@ def resolve_provider(
109
122
return self ._resolve_config_option (provider )
110
123
elif isinstance (provider , providers .TypedConfigurationOption ):
111
124
return self ._resolve_config_option (provider .option , as_ = provider .provides )
125
+ elif isinstance (provider , str ):
126
+ return self ._resolve_string_id (provider , modifier )
112
127
else :
113
128
return self ._resolve_provider (provider )
114
129
115
- def _resolve_delegate (
130
+ def _resolve_string_id (
116
131
self ,
117
- original : providers .Delegate ,
132
+ id : str ,
133
+ modifier : Optional ['Modifier' ] = None ,
118
134
) -> Optional [providers .Provider ]:
119
- return self ._resolve_provider (original .provides )
135
+ if id == self .CONTAINER_STRING_ID :
136
+ return self ._container .__self__
137
+
138
+ provider = self ._container
139
+ for segment in id .split ('.' ):
140
+ try :
141
+ provider = getattr (provider , segment )
142
+ except AttributeError :
143
+ return None
144
+
145
+ if modifier :
146
+ provider = modifier .modify (provider , providers_map = self )
147
+ return provider
120
148
121
149
def _resolve_provided_instance (
122
150
self ,
@@ -151,6 +179,12 @@ def _resolve_provided_instance(
151
179
152
180
return new
153
181
182
+ def _resolve_delegate (
183
+ self ,
184
+ original : providers .Delegate ,
185
+ ) -> Optional [providers .Provider ]:
186
+ return self ._resolve_provider (original .provides )
187
+
154
188
def _resolve_config_option (
155
189
self ,
156
190
original : providers .ConfigurationOption ,
@@ -184,7 +218,7 @@ def _resolve_provider(
184
218
try :
185
219
return self ._map [original ]
186
220
except KeyError :
187
- pass
221
+ return None
188
222
189
223
@classmethod
190
224
def _create_providers_map (
@@ -381,7 +415,7 @@ def _fetch_reference_injections(
381
415
382
416
def _bind_injections (fn : Callable [..., Any ], providers_map : ProvidersMap ) -> None :
383
417
for injection , marker in fn .__reference_injections__ .items ():
384
- provider = providers_map .resolve_provider (marker .provider )
418
+ provider = providers_map .resolve_provider (marker .provider , marker . modifier )
385
419
386
420
if provider is None :
387
421
continue
@@ -516,20 +550,161 @@ def _is_declarative_container(instance: Any) -> bool:
516
550
and getattr (instance , 'declarative_parent' , None ) is None )
517
551
518
552
553
+ class Modifier :
554
+
555
+ def modify (
556
+ self ,
557
+ provider : providers .ConfigurationOption ,
558
+ providers_map : ProvidersMap ,
559
+ ) -> providers .Provider :
560
+ ...
561
+
562
+
563
+ class TypeModifier (Modifier ):
564
+
565
+ def __init__ (self , type_ : Type ):
566
+ self .type_ = type_
567
+
568
+ def modify (
569
+ self ,
570
+ provider : providers .ConfigurationOption ,
571
+ providers_map : ProvidersMap ,
572
+ ) -> providers .Provider :
573
+ return provider .as_ (self .type_ )
574
+
575
+
576
+ def as_int () -> TypeModifier :
577
+ """Return int type modifier."""
578
+ return TypeModifier (int )
579
+
580
+
581
+ def as_float () -> TypeModifier :
582
+ """Return float type modifier."""
583
+ return TypeModifier (float )
584
+
585
+
586
+ def as_ (type_ : Type ) -> TypeModifier :
587
+ """Return custom type modifier."""
588
+ return TypeModifier (type_ )
589
+
590
+
591
+ class RequiredModifier (Modifier ):
592
+
593
+ def __init__ (self ):
594
+ self .type_modifier = None
595
+
596
+ def as_int (self ) -> 'RequiredModifier' :
597
+ self .type_modifier = TypeModifier (int )
598
+ return self
599
+
600
+ def as_float (self ) -> 'RequiredModifier' :
601
+ self .type_modifier = TypeModifier (float )
602
+ return self
603
+
604
+ def as_ (self , type_ : Type ) -> 'RequiredModifier' :
605
+ self .type_modifier = TypeModifier (type_ )
606
+ return self
607
+
608
+ def modify (
609
+ self ,
610
+ provider : providers .ConfigurationOption ,
611
+ providers_map : ProvidersMap ,
612
+ ) -> providers .Provider :
613
+ provider = provider .required ()
614
+ if self .type_modifier :
615
+ provider = provider .as_ (self .type_modifier .type_ )
616
+ return provider
617
+
618
+
619
+ def required () -> RequiredModifier :
620
+ """Return required modifier."""
621
+ return RequiredModifier ()
622
+
623
+
624
+ class InvariantModifier (Modifier ):
625
+
626
+ def __init__ (self , id : str ) -> None :
627
+ self .id = id
628
+
629
+ def modify (
630
+ self ,
631
+ provider : providers .ConfigurationOption ,
632
+ providers_map : ProvidersMap ,
633
+ ) -> providers .Provider :
634
+ invariant_segment = providers_map .resolve_provider (self .id )
635
+ return provider [invariant_segment ]
636
+
637
+
638
+ def invariant (id : str ) -> InvariantModifier :
639
+ """Return invariant modifier."""
640
+ return InvariantModifier (id )
641
+
642
+
643
+ class ProvidedInstance (Modifier ):
644
+
645
+ TYPE_ATTRIBUTE = 'attr'
646
+ TYPE_ITEM = 'item'
647
+ TYPE_CALL = 'call'
648
+
649
+ def __init__ (self ):
650
+ self .segments = []
651
+
652
+ def __getattr__ (self , item ):
653
+ self .segments .append ((self .TYPE_ATTRIBUTE , item ))
654
+ return self
655
+
656
+ def __getitem__ (self , item ):
657
+ self .segments .append ((self .TYPE_ITEM , item ))
658
+ return self
659
+
660
+ def call (self ):
661
+ self .segments .append ((self .TYPE_CALL , None ))
662
+ return self
663
+
664
+ def modify (
665
+ self ,
666
+ provider : providers .ConfigurationOption ,
667
+ providers_map : ProvidersMap ,
668
+ ) -> providers .Provider :
669
+ provider = provider .provided
670
+ for type_ , value in self .segments :
671
+ if type_ == ProvidedInstance .TYPE_ATTRIBUTE :
672
+ provider = getattr (provider , value )
673
+ elif type_ == ProvidedInstance .TYPE_ITEM :
674
+ provider = provider [value ]
675
+ elif type_ == ProvidedInstance .TYPE_CALL :
676
+ provider = provider .call ()
677
+ return provider
678
+
679
+
680
+ def provided () -> ProvidedInstance :
681
+ """Return provided instance modifier."""
682
+ return ProvidedInstance ()
683
+
684
+
519
685
class ClassGetItemMeta (GenericMeta ):
520
686
def __getitem__ (cls , item ):
521
687
# Spike for Python 3.6
688
+ if isinstance (item , tuple ):
689
+ return cls (* item )
522
690
return cls (item )
523
691
524
692
525
693
class _Marker (Generic [T ], metaclass = ClassGetItemMeta ):
526
694
527
- def __init__ (self , provider : Union [providers .Provider , Container ]) -> None :
695
+ def __init__ (
696
+ self ,
697
+ provider : Union [providers .Provider , Container , str ],
698
+ modifier : Optional [Modifier ] = None ,
699
+ ) -> None :
528
700
if _is_declarative_container (provider ):
529
701
provider = provider .__self__
530
- self .provider : providers .Provider = provider
702
+ self .provider = provider
703
+ self .modifier = modifier
531
704
532
705
def __class_getitem__ (cls , item ) -> T :
706
+ if isinstance (item , tuple ):
707
+ return cls (* item )
533
708
return cls (item )
534
709
535
710
def __call__ (self ) -> T :
0 commit comments