Skip to content

Commit e75995b

Browse files
committed
feat(parent): Support strawberry.Parent with future annotations
Fix #3481
1 parent f6ac9bb commit e75995b

File tree

7 files changed

+223
-10
lines changed

7 files changed

+223
-10
lines changed

RELEASE.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
Release type: minor
2+
3+
This release adds support to use `strawberry.Parent` with future annotations.
4+
5+
For example, the following code will now work as intended:
6+
7+
```python
8+
from __future__ import annotations
9+
10+
11+
def get_full_name(user: strawberry.Parent[User]) -> str:
12+
return f"{user.first_name} {user.last_name}"
13+
14+
15+
@strawberry.type
16+
class User:
17+
first_name: str
18+
last_name: str
19+
full_name: str = strawberry.field(resolver=get_full_name)
20+
21+
22+
@strawberry.type
23+
class Query:
24+
@strawberry.field
25+
def user(self) -> User:
26+
return User(first_name="John", last_name="Doe")
27+
28+
29+
schema = strawberry.Schema(query=Query)
30+
```
31+
32+
Or even when not using future annotations, but delaying the evaluation of `User`, like:
33+
34+
35+
```python
36+
# Note the User being delayed by passing it as a string
37+
def get_full_name(user: strawberry.Parent["User"]) -> str:
38+
return f"{user.first_name} {user.last_name}"
39+
40+
41+
@strawberry.type
42+
class User:
43+
first_name: str
44+
last_name: str
45+
full_name: str = strawberry.field(resolver=get_full_name)
46+
47+
48+
@strawberry.type
49+
class Query:
50+
@strawberry.field
51+
def user(self) -> User:
52+
return User(first_name="John", last_name="Doe")
53+
54+
55+
schema = strawberry.Schema(query=Query)
56+
```

strawberry/parent.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from typing import Annotated, TypeVar
1+
import re
2+
from typing import Annotated, Any, ForwardRef, TypeVar, Union
3+
4+
_parent_re = re.compile(r"^(?:strawberry\.)?Parent\[(.*)\]$")
25

36

47
class StrawberryParent: ...
@@ -40,4 +43,20 @@ def user(self) -> User:
4043
```
4144
"""
4245

43-
__all__ = ["Parent"]
46+
47+
def resolve_parent_forward_arg(annotation: Union[str, ForwardRef]) -> Any:
48+
if isinstance(annotation, str):
49+
str_annotation = annotation
50+
elif isinstance(annotation, ForwardRef):
51+
str_annotation = annotation.__forward_arg__
52+
else:
53+
# If neither, return the annotation as is
54+
return annotation
55+
56+
if parent_match := _parent_re.match(str_annotation):
57+
annotation = Parent[ForwardRef(parent_match.group(1))]
58+
59+
return annotation
60+
61+
62+
__all__ = ["Parent", "StrawberryParent", "resolve_parent_forward_arg"]

strawberry/types/field.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,22 @@ def __call__(self, resolver: _RESOLVER_TYPE) -> Self:
192192
if isinstance(argument.type_annotation.annotation, str):
193193
continue
194194

195-
if isinstance(argument.type, StrawberryUnion):
195+
try:
196+
argument_type = argument.type
197+
except NameError:
198+
# This is a forward reference which we can't resolve yet,
199+
# ok to skip for now...
200+
continue
201+
202+
if isinstance(argument_type, StrawberryUnion):
196203
raise InvalidArgumentTypeError(
197204
resolver,
198205
argument,
199206
)
200207

201208
if (
202-
has_object_definition(argument.type)
203-
and argument.type.__strawberry_definition__.is_interface
209+
has_object_definition(argument_type)
210+
and argument_type.__strawberry_definition__.is_interface
204211
):
205212
raise InvalidArgumentTypeError(
206213
resolver,

strawberry/types/fields/resolver.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ConflictingArgumentsError,
2626
MissingArgumentsAnnotationsError,
2727
)
28-
from strawberry.parent import StrawberryParent
28+
from strawberry.parent import StrawberryParent, resolve_parent_forward_arg
2929
from strawberry.types.arguments import StrawberryArgument
3030
from strawberry.types.base import StrawberryType, has_object_definition
3131
from strawberry.types.info import Info
@@ -118,10 +118,17 @@ def find(
118118
try:
119119
evaled_annotation = annotation.evaluate()
120120
except NameError:
121-
continue
122-
else:
123-
if self.is_reserved_type(evaled_annotation):
124-
type_parameters.append(parameter)
121+
# If this is a strabwerry.Parent using ForwardRef, we will fail to
122+
# evaluate at this moment, but at least knowing that it is a reserved
123+
# type is enough for now
124+
# We might want to revisit this in the future, maybe by postponing
125+
# this check to when the schema is actually being created
126+
evaled_annotation = resolve_parent_forward_arg(
127+
annotation.annotation
128+
)
129+
130+
if self.is_reserved_type(evaled_annotation):
131+
type_parameters.append(parameter)
125132

126133
if len(type_parameters) > 1:
127134
raise ConflictingArgumentsError(

strawberry/utils/typing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,13 @@ def eval_type(
304304
localns: Optional[dict] = None,
305305
) -> type:
306306
"""Evaluates a type, resolving forward references."""
307+
from strawberry.parent import StrawberryParent
307308
from strawberry.types.auto import StrawberryAuto
308309
from strawberry.types.lazy_type import StrawberryLazyReference
309310
from strawberry.types.private import StrawberryPrivate
310311

311312
globalns = globalns or {}
313+
312314
# If this is not a string, maybe its args are (e.g. list["Foo"])
313315
if isinstance(type_, ForwardRef):
314316
ast_obj = cast("ast.Expr", ast.parse(type_.__forward_arg__).body[0])
@@ -355,13 +357,29 @@ def eval_type(
355357
)
356358
args = (type_arg, *remaining_args)
357359
break
360+
358361
if isinstance(arg, StrawberryAuto):
359362
remaining_args = [
360363
a for a in args[1:] if not isinstance(a, StrawberryAuto)
361364
]
362365
args = (args[0], arg, *remaining_args)
363366
break
364367

368+
if isinstance(arg, StrawberryParent):
369+
remaining_args = [
370+
a for a in args[1:] if not isinstance(a, StrawberryParent)
371+
]
372+
try:
373+
type_arg = (
374+
eval_type(args[0], globalns, localns)
375+
if isinstance(args[0], ForwardRef)
376+
else args[0]
377+
)
378+
except (NameError, TypeError):
379+
type_arg = args[0]
380+
args = (type_arg, arg, *remaining_args)
381+
break
382+
365383
# If we have only a StrawberryLazyReference and no more annotations,
366384
# we need to return the argument directly because Annotated
367385
# will raise an error if trying to instantiate it with only

tests/types/test_parent_type.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import textwrap
2+
3+
import strawberry
4+
5+
6+
def test_parent_type():
7+
global User
8+
9+
try:
10+
11+
def get_full_name(user: strawberry.Parent["User"]) -> str:
12+
return f"{user.first_name} {user.last_name}"
13+
14+
@strawberry.type
15+
class User:
16+
first_name: str
17+
last_name: str
18+
full_name: str = strawberry.field(resolver=get_full_name)
19+
20+
@strawberry.type
21+
class Query:
22+
@strawberry.field
23+
def user(self) -> User:
24+
return User(first_name="John", last_name="Doe") # noqa: F821
25+
26+
schema = strawberry.Schema(query=Query)
27+
28+
expected = """\
29+
type Query {
30+
user: User!
31+
}
32+
33+
type User {
34+
firstName: String!
35+
lastName: String!
36+
fullName: String!
37+
}
38+
"""
39+
assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip()
40+
41+
query = "{ user { firstName, lastName, fullName } }"
42+
result = schema.execute_sync(query)
43+
assert not result.errors
44+
assert result.data == {
45+
"user": {
46+
"firstName": "John",
47+
"lastName": "Doe",
48+
"fullName": "John Doe",
49+
}
50+
}
51+
finally:
52+
del User
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import annotations
2+
3+
import textwrap
4+
5+
import strawberry
6+
7+
8+
def test_parent_type():
9+
global User
10+
11+
try:
12+
13+
def get_full_name(user: strawberry.Parent[User]) -> str:
14+
return f"{user.first_name} {user.last_name}"
15+
16+
@strawberry.type
17+
class User:
18+
first_name: str
19+
last_name: str
20+
full_name: str = strawberry.field(resolver=get_full_name)
21+
22+
@strawberry.type
23+
class Query:
24+
@strawberry.field
25+
def user(self) -> User:
26+
return User(first_name="John", last_name="Doe") # noqa: F821
27+
28+
schema = strawberry.Schema(query=Query)
29+
30+
expected = """\
31+
type Query {
32+
user: User!
33+
}
34+
35+
type User {
36+
firstName: String!
37+
lastName: String!
38+
fullName: String!
39+
}
40+
"""
41+
assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip()
42+
43+
query = "{ user { firstName, lastName, fullName } }"
44+
result = schema.execute_sync(query)
45+
assert not result.errors
46+
assert result.data == {
47+
"user": {
48+
"firstName": "John",
49+
"lastName": "Doe",
50+
"fullName": "John Doe",
51+
}
52+
}
53+
finally:
54+
del User

0 commit comments

Comments
 (0)