Skip to content

Commit ecfff4a

Browse files
committed
Added a separate QueryHeaderArgsResolverGenerator to manage args generation for header and query params.
1 parent b71e726 commit ecfff4a

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

ellar/core/params/args/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,16 @@
3636
BulkArgsResolverGenerator,
3737
FormArgsResolverGenerator,
3838
PathArgsResolverGenerator,
39+
QueryHeaderResolverGenerator,
3940
)
4041

4142

4243
class EndpointArgsModel:
4344
_bulk_resolvers_generators = {
4445
str(params.Form): FormArgsResolverGenerator,
4546
str(params.Path): PathArgsResolverGenerator,
47+
str(params.Query): QueryHeaderResolverGenerator,
48+
str(params.Header): QueryHeaderResolverGenerator,
4649
}
4750

4851
_provider_skip = primitive_types + sequence_types + (Representation,)

ellar/core/params/args/resolver_generators.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,42 @@
1616

1717

1818
class BulkArgsResolverGenerator:
19+
"""
20+
This class splits Schema into different ModelFields to he resolved independently and computed back later.
21+
class ASchema(BaseModel):
22+
A: int
23+
B: int
24+
25+
def endpoint(a: ASchema = Query())
26+
pass
27+
28+
args_resolver = BulkArgsResolverGenerator(a) where `a` is `endpoint` parameter.
29+
args_resolver.generate_resolvers() == [AModelFieldResolver, BModelFieldResolver]
30+
31+
The generated ModelFieldResolvers are saved to self.param_field.field_info.extra with MULTI_RESOLVER_KEY key
32+
Which will be available when creating a resolver and cleared afterwards
33+
34+
def create_resolver(self, model_field: ModelField) -> RouteParameterResolver:
35+
multiple_resolvers = model_field.field_info.extra.get(MULTI_RESOLVER_KEY)
36+
if multiple_resolvers:
37+
model_field.field_info.extra.clear()
38+
return self.bulk_resolver(
39+
model_field=model_field, resolvers=multiple_resolvers
40+
)
41+
return self.resolver(model_field)
42+
"""
43+
1944
__slots__ = ("param_field", "pydantic_outer_type")
2045

2146
def __init__(self, pydantic_type: ModelField) -> None:
2247
self.pydantic_outer_type = t.cast(BaseModel, pydantic_type.outer_type_)
2348
self.param_field = pydantic_type
2449

2550
def validate(self, field_name: str, field: ModelField) -> None:
26-
if not is_scalar_field(field=field) and not is_scalar_sequence_field(
27-
self.param_field
28-
):
51+
if not is_scalar_field(field=field):
2952
raise ImproperConfiguration(
30-
f"field: '{field_name}' with annotation:'{field.type_}' "
31-
f"can't be processed. Field type should belong to {sequence_types} "
32-
f"or any primitive type"
53+
f"field: '{field_name}' with annotation:'{field.outer_type_}' in '{self.param_field.type_}'"
54+
f"can't be processed. Field type is not a primitive type"
3355
)
3456

3557
def get_parameter_field(
@@ -91,21 +113,28 @@ def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
91113
self.param_field.field_info.extra[MULTI_RESOLVER_KEY] = resolvers
92114

93115

94-
class FormArgsResolverGenerator(BulkArgsResolverGenerator):
116+
class QueryHeaderResolverGenerator(BulkArgsResolverGenerator):
95117
def validate(self, field_name: str, field: ModelField) -> None:
96-
""" "Do nothing"""
118+
if not is_scalar_field(field=field) and not is_scalar_sequence_field(field):
119+
raise ImproperConfiguration(
120+
f"field: '{field_name}' with annotation:'{field.outer_type_}' in '{self.param_field.type_}'"
121+
f"can't be processed. Field type should belong to {sequence_types} "
122+
f"or any primitive type"
123+
)
124+
97125

126+
class FormArgsResolverGenerator(QueryHeaderResolverGenerator):
98127
def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
99128
super().generate_resolvers(body_field_class=body_field_class)
100129
self.param_field.field_info.extra[MULTI_RESOLVER_FORM_GROUPED_KEY] = True
101130

102131

103132
class PathArgsResolverGenerator(BulkArgsResolverGenerator):
104133
def validate(self, field_name: str, field: ModelField) -> None:
105-
""" "Do nothing"""
106-
assert is_scalar_field(
107-
field=field
108-
), "Path params must be of one of the supported types"
134+
if not is_scalar_field(field=field):
135+
raise ImproperConfiguration(
136+
"Path params must be of one of the supported types. Only primitive types"
137+
)
109138

110139
def get_parameter_field(
111140
self,

0 commit comments

Comments
 (0)