Skip to content

feat(parent): Support strawberry.Parent with future annotations #3851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
Release type: minor

This release adds support to use `strawberry.Parent` with future annotations.

For example, the following code will now work as intended:

```python
from __future__ import annotations


def get_full_name(user: strawberry.Parent[User]) -> str:
return f"{user.first_name} {user.last_name}"


@strawberry.type
class User:
first_name: str
last_name: str
full_name: str = strawberry.field(resolver=get_full_name)


@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(first_name="John", last_name="Doe")


schema = strawberry.Schema(query=Query)
```

Or even when not using future annotations, but delaying the evaluation of `User`, like:


```python
# Note the User being delayed by passing it as a string
def get_full_name(user: strawberry.Parent["User"]) -> str:
return f"{user.first_name} {user.last_name}"


@strawberry.type
class User:
first_name: str
last_name: str
full_name: str = strawberry.field(resolver=get_full_name)


@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(first_name="John", last_name="Doe")


schema = strawberry.Schema(query=Query)
```
23 changes: 21 additions & 2 deletions strawberry/parent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Annotated, TypeVar
import re
from typing import Annotated, Any, ForwardRef, TypeVar

_parent_re = re.compile(r"^(?:strawberry\.)?Parent\[(.*)\]$")


class StrawberryParent: ...
Expand Down Expand Up @@ -40,4 +43,20 @@ def user(self) -> User:
```
"""

__all__ = ["Parent"]

def resolve_parent_forward_arg(annotation: Any) -> Any:
if isinstance(annotation, str):
str_annotation = annotation
elif isinstance(annotation, ForwardRef):
str_annotation = annotation.__forward_arg__
else:
# If neither, return the annotation as is
return annotation

if parent_match := _parent_re.match(str_annotation):
annotation = Parent[ForwardRef(parent_match.group(1))] # type: ignore[misc]

return annotation


__all__ = ["Parent", "StrawberryParent", "resolve_parent_forward_arg"]
17 changes: 12 additions & 5 deletions strawberry/types/fields/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
ConflictingArgumentsError,
MissingArgumentsAnnotationsError,
)
from strawberry.parent import StrawberryParent
from strawberry.parent import StrawberryParent, resolve_parent_forward_arg
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.base import StrawberryType, has_object_definition
from strawberry.types.info import Info
Expand Down Expand Up @@ -118,10 +118,17 @@ def find(
try:
evaled_annotation = annotation.evaluate()
except NameError:
continue
else:
if self.is_reserved_type(evaled_annotation):
type_parameters.append(parameter)
# If this is a strawberry.Parent using ForwardRef, we will fail to
# evaluate at this moment, but at least knowing that it is a reserved
# type is enough for now
# We might want to revisit this in the future, maybe by postponing
# this check to when the schema is actually being created
evaled_annotation = resolve_parent_forward_arg(
annotation.annotation
)

if self.is_reserved_type(evaled_annotation):
type_parameters.append(parameter)

if len(type_parameters) > 1:
raise ConflictingArgumentsError(
Expand Down
18 changes: 18 additions & 0 deletions strawberry/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,13 @@ def eval_type(
localns: Optional[dict] = None,
) -> type:
"""Evaluates a type, resolving forward references."""
from strawberry.parent import StrawberryParent
from strawberry.types.auto import StrawberryAuto
from strawberry.types.lazy_type import StrawberryLazyReference
from strawberry.types.private import StrawberryPrivate

globalns = globalns or {}

# If this is not a string, maybe its args are (e.g. list["Foo"])
if isinstance(type_, ForwardRef):
ast_obj = cast("ast.Expr", ast.parse(type_.__forward_arg__).body[0])
Expand Down Expand Up @@ -355,13 +357,29 @@ def eval_type(
)
args = (type_arg, *remaining_args)
break

if isinstance(arg, StrawberryAuto):
remaining_args = [
a for a in args[1:] if not isinstance(a, StrawberryAuto)
]
args = (args[0], arg, *remaining_args)
break

if isinstance(arg, StrawberryParent):
remaining_args = [
a for a in args[1:] if not isinstance(a, StrawberryParent)
]
try:
type_arg = (
eval_type(args[0], globalns, localns)
if isinstance(args[0], ForwardRef)
else args[0]
)
except (NameError, TypeError):
type_arg = args[0]
args = (type_arg, arg, *remaining_args)
break

# If we have only a StrawberryLazyReference and no more annotations,
# we need to return the argument directly because Annotated
# will raise an error if trying to instantiate it with only
Expand Down
70 changes: 70 additions & 0 deletions tests/types/test_parent_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import textwrap
from typing import ForwardRef

import pytest

import strawberry
from strawberry.parent import resolve_parent_forward_arg


def test_parent_type():
global User

try:

def get_full_name(user: strawberry.Parent["User"]) -> str:
return f"{user.first_name} {user.last_name}"

@strawberry.type
class User:
first_name: str
last_name: str
full_name: str = strawberry.field(resolver=get_full_name)

@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(first_name="John", last_name="Doe") # noqa: F821

schema = strawberry.Schema(query=Query)

expected = """\
type Query {
user: User!
}

type User {
firstName: String!
lastName: String!
fullName: String!
}
"""
assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip()

query = "{ user { firstName, lastName, fullName } }"
result = schema.execute_sync(query)
assert not result.errors
assert result.data == {
"user": {
"firstName": "John",
"lastName": "Doe",
"fullName": "John Doe",
}
}
finally:
del User


@pytest.mark.parametrize(
("annotation", "expected"),
[
("strawberry.Parent[str]", strawberry.Parent["str"]),
("Parent[str]", strawberry.Parent["str"]),
(ForwardRef("strawberry.Parent[str]"), strawberry.Parent["str"]),
(ForwardRef("Parent[str]"), strawberry.Parent["str"]),
(strawberry.Parent["User"], strawberry.Parent["User"]),
],
)
def test_resolve_parent_forward_arg(annotation, expected):
assert resolve_parent_forward_arg(annotation) == expected
54 changes: 54 additions & 0 deletions tests/types/test_parent_type_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import textwrap

import strawberry


def test_parent_type():
global User

try:

def get_full_name(user: strawberry.Parent[User]) -> str:
return f"{user.first_name} {user.last_name}"

@strawberry.type
class User:
first_name: str
last_name: str
full_name: str = strawberry.field(resolver=get_full_name)

@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(first_name="John", last_name="Doe") # noqa: F821

schema = strawberry.Schema(query=Query)

expected = """\
type Query {
user: User!
}

type User {
firstName: String!
lastName: String!
fullName: String!
}
"""
assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip()

query = "{ user { firstName, lastName, fullName } }"
result = schema.execute_sync(query)
assert not result.errors
assert result.data == {
"user": {
"firstName": "John",
"lastName": "Doe",
"fullName": "John Doe",
}
}
finally:
del User
Loading