Skip to content

Commit e857ce4

Browse files
committed
Moved TestClientFactory from core to testing package and refactored testing module
1 parent a3dfe25 commit e857ce4

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

ellar/testing/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from starlette.testclient import TestClient as TestClient
2+
3+
from .module import Test
4+
5+
__all__ = [
6+
"Test",
7+
"TestClient",
8+
]
Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,61 @@
11
import typing as t
22
from pathlib import Path
3+
from uuid import uuid4
34

45
from starlette.testclient import TestClient as TestClient
56

7+
from ellar.common import Module
68
from ellar.core import ModuleBase
79
from ellar.core.factory import AppFactory
810
from ellar.core.main import App
911
from ellar.core.routing import ModuleRouter
1012
from ellar.di import ProviderConfig
13+
from ellar.types import T
1114

1215
if t.TYPE_CHECKING: # pragma: no cover
1316
from ellar.core import GuardCanActivate
1417

1518

16-
class _TestingModule:
17-
def __init__(self, app: App) -> None:
18-
self.app = app
19+
class TestingModule:
20+
def __init__(
21+
self,
22+
testing_module: t.Type[t.Union[ModuleBase, t.Any]],
23+
global_guards: t.List[
24+
t.Union[t.Type["GuardCanActivate"], "GuardCanActivate"]
25+
] = None,
26+
config_module: t.Union[str, t.Dict] = None,
27+
) -> None:
28+
self._testing_module = testing_module
29+
self._config_module = config_module
30+
self._global_guards = global_guards
31+
self._providers: t.List[ProviderConfig] = []
32+
self._app: t.Optional[App] = None
33+
34+
def override_provider(
35+
self,
36+
base_type: t.Union[t.Type[T], t.Type],
37+
*,
38+
use_value: T = None,
39+
use_class: t.Union[t.Type[T], t.Any] = None,
40+
) -> "TestingModule":
41+
provider_config = ProviderConfig(
42+
base_type, use_class=use_class, use_value=use_value
43+
)
44+
self._providers.append(provider_config)
45+
return self
46+
47+
def create_application(self) -> App:
48+
if self._app:
49+
return self._app
50+
self._app = AppFactory.create_app(
51+
modules=[self._testing_module],
52+
global_guards=self._global_guards,
53+
config_module=self._config_module,
54+
providers=self._providers,
55+
)
56+
return self._app
1957

20-
def get_client(
58+
def get_test_client(
2159
self,
2260
base_url: str = "http://testserver",
2361
raise_server_exceptions: bool = True,
@@ -26,20 +64,25 @@ def get_client(
2664
backend_options: t.Optional[t.Dict[str, t.Any]] = None,
2765
) -> TestClient:
2866
return TestClient(
29-
app=self.app,
67+
app=self.create_application(),
3068
base_url=base_url,
3169
raise_server_exceptions=raise_server_exceptions,
3270
backend=backend,
3371
backend_options=backend_options,
3472
root_path=root_path,
3573
)
3674

75+
def get(self, interface: t.Type[T]) -> T:
76+
return self.create_application().injector.get(interface) # type: ignore[no-any-return]
77+
78+
79+
class Test:
80+
TESTING_MODULE = TestingModule
3781

38-
class TestClientFactory:
3982
@classmethod
4083
def create_test_module(
4184
cls,
42-
modules: t.Sequence[t.Union[t.Type, t.Any]] = tuple(),
85+
modules: t.Sequence[t.Type[t.Union[ModuleBase, t.Any]]] = tuple(),
4386
controllers: t.Sequence[t.Union[t.Any]] = tuple(),
4487
routers: t.Sequence[ModuleRouter] = tuple(),
4588
providers: t.Sequence[ProviderConfig] = tuple(),
@@ -50,7 +93,7 @@ def create_test_module(
5093
t.Union[t.Type["GuardCanActivate"], "GuardCanActivate"]
5194
] = None,
5295
config_module: t.Union[str, t.Dict] = None,
53-
) -> _TestingModule:
96+
) -> TestingModule:
5497
"""
5598
Create a TestingModule to test controllers and services in isolation
5699
:param modules:
@@ -64,30 +107,28 @@ def create_test_module(
64107
:param global_guards:
65108
:return:
66109
"""
67-
app = AppFactory.create_app(
110+
module = Module(
68111
modules=modules,
69112
controllers=controllers,
70113
routers=routers,
71114
providers=providers,
72115
template_folder=template_folder,
73116
base_directory=base_directory,
74117
static_folder=static_folder,
75-
config_module=config_module,
76-
global_guards=global_guards,
77118
)
78-
return _TestingModule(app=app)
119+
testing_module = type(f"TestingModule_{uuid4().hex[:6]}", (ModuleBase,), {})
120+
module(testing_module)
121+
return cls.TESTING_MODULE(
122+
testing_module, global_guards=global_guards, config_module=config_module
123+
)
79124

80125
@classmethod
81126
def create_test_module_from_module(
82127
cls,
83-
module: t.Union[t.Type, t.Type[ModuleBase]],
84-
mock_providers: t.Sequence[ProviderConfig] = tuple(),
128+
module: t.Type[t.Union[ModuleBase, t.Any]],
85129
config_module: str = None,
86-
) -> _TestingModule:
130+
) -> TestingModule:
87131
"""
88132
Create a TestingModule from an existing module
89133
"""
90-
app = AppFactory.create_app(
91-
modules=(module,), providers=mock_providers, config_module=config_module
92-
)
93-
return _TestingModule(app=app)
134+
return cls.TESTING_MODULE(module, global_guards=[], config_module=config_module)

0 commit comments

Comments
 (0)