1
1
import typing as t
2
2
from pathlib import Path
3
+ from uuid import uuid4
3
4
4
5
from starlette .testclient import TestClient as TestClient
5
6
7
+ from ellar .common import Module
6
8
from ellar .core import ModuleBase
7
9
from ellar .core .factory import AppFactory
8
10
from ellar .core .main import App
9
11
from ellar .core .routing import ModuleRouter
10
12
from ellar .di import ProviderConfig
13
+ from ellar .types import T
11
14
12
15
if t .TYPE_CHECKING : # pragma: no cover
13
16
from ellar .core import GuardCanActivate
14
17
15
18
16
- class _TestingModule :
17
- def __init__ (self , app : App ) -> None :
18
- self .app = app
19
+ class TestingModule :
20
+ def __init__ (
21
+ self ,
22
+ testing_module : t .Type [t .Union [ModuleBase , t .Any ]],
23
+ global_guards : t .List [
24
+ t .Union [t .Type ["GuardCanActivate" ], "GuardCanActivate" ]
25
+ ] = None ,
26
+ config_module : t .Union [str , t .Dict ] = None ,
27
+ ) -> None :
28
+ self ._testing_module = testing_module
29
+ self ._config_module = config_module
30
+ self ._global_guards = global_guards
31
+ self ._providers : t .List [ProviderConfig ] = []
32
+ self ._app : t .Optional [App ] = None
33
+
34
+ def override_provider (
35
+ self ,
36
+ base_type : t .Union [t .Type [T ], t .Type ],
37
+ * ,
38
+ use_value : T = None ,
39
+ use_class : t .Union [t .Type [T ], t .Any ] = None ,
40
+ ) -> "TestingModule" :
41
+ provider_config = ProviderConfig (
42
+ base_type , use_class = use_class , use_value = use_value
43
+ )
44
+ self ._providers .append (provider_config )
45
+ return self
46
+
47
+ def create_application (self ) -> App :
48
+ if self ._app :
49
+ return self ._app
50
+ self ._app = AppFactory .create_app (
51
+ modules = [self ._testing_module ],
52
+ global_guards = self ._global_guards ,
53
+ config_module = self ._config_module ,
54
+ providers = self ._providers ,
55
+ )
56
+ return self ._app
19
57
20
- def get_client (
58
+ def get_test_client (
21
59
self ,
22
60
base_url : str = "http://testserver" ,
23
61
raise_server_exceptions : bool = True ,
@@ -26,20 +64,25 @@ def get_client(
26
64
backend_options : t .Optional [t .Dict [str , t .Any ]] = None ,
27
65
) -> TestClient :
28
66
return TestClient (
29
- app = self .app ,
67
+ app = self .create_application () ,
30
68
base_url = base_url ,
31
69
raise_server_exceptions = raise_server_exceptions ,
32
70
backend = backend ,
33
71
backend_options = backend_options ,
34
72
root_path = root_path ,
35
73
)
36
74
75
+ def get (self , interface : t .Type [T ]) -> T :
76
+ return self .create_application ().injector .get (interface ) # type: ignore[no-any-return]
77
+
78
+
79
+ class Test :
80
+ TESTING_MODULE = TestingModule
37
81
38
- class TestClientFactory :
39
82
@classmethod
40
83
def create_test_module (
41
84
cls ,
42
- modules : t .Sequence [t .Union [t .Type , t .Any ]] = tuple (),
85
+ modules : t .Sequence [t .Type [t .Union [ ModuleBase , t .Any ] ]] = tuple (),
43
86
controllers : t .Sequence [t .Union [t .Any ]] = tuple (),
44
87
routers : t .Sequence [ModuleRouter ] = tuple (),
45
88
providers : t .Sequence [ProviderConfig ] = tuple (),
@@ -50,7 +93,7 @@ def create_test_module(
50
93
t .Union [t .Type ["GuardCanActivate" ], "GuardCanActivate" ]
51
94
] = None ,
52
95
config_module : t .Union [str , t .Dict ] = None ,
53
- ) -> _TestingModule :
96
+ ) -> TestingModule :
54
97
"""
55
98
Create a TestingModule to test controllers and services in isolation
56
99
:param modules:
@@ -64,30 +107,28 @@ def create_test_module(
64
107
:param global_guards:
65
108
:return:
66
109
"""
67
- app = AppFactory . create_app (
110
+ module = Module (
68
111
modules = modules ,
69
112
controllers = controllers ,
70
113
routers = routers ,
71
114
providers = providers ,
72
115
template_folder = template_folder ,
73
116
base_directory = base_directory ,
74
117
static_folder = static_folder ,
75
- config_module = config_module ,
76
- global_guards = global_guards ,
77
118
)
78
- return _TestingModule (app = app )
119
+ testing_module = type (f"TestingModule_{ uuid4 ().hex [:6 ]} " , (ModuleBase ,), {})
120
+ module (testing_module )
121
+ return cls .TESTING_MODULE (
122
+ testing_module , global_guards = global_guards , config_module = config_module
123
+ )
79
124
80
125
@classmethod
81
126
def create_test_module_from_module (
82
127
cls ,
83
- module : t .Union [t .Type , t .Type [ModuleBase ]],
84
- mock_providers : t .Sequence [ProviderConfig ] = tuple (),
128
+ module : t .Type [t .Union [ModuleBase , t .Any ]],
85
129
config_module : str = None ,
86
- ) -> _TestingModule :
130
+ ) -> TestingModule :
87
131
"""
88
132
Create a TestingModule from an existing module
89
133
"""
90
- app = AppFactory .create_app (
91
- modules = (module ,), providers = mock_providers , config_module = config_module
92
- )
93
- return _TestingModule (app = app )
134
+ return cls .TESTING_MODULE (module , global_guards = [], config_module = config_module )
0 commit comments