diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..48f8dbd0b3 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,56 @@ +Release type: minor + +This release adds support to use `strawberry.Parent` with future annotations. + +For example, the following code will now work as intended: + +```python +from __future__ import annotations + + +def get_full_name(user: strawberry.Parent[User]) -> str: + return f"{user.first_name} {user.last_name}" + + +@strawberry.type +class User: + first_name: str + last_name: str + full_name: str = strawberry.field(resolver=get_full_name) + + +@strawberry.type +class Query: + @strawberry.field + def user(self) -> User: + return User(first_name="John", last_name="Doe") + + +schema = strawberry.Schema(query=Query) +``` + +Or even when not using future annotations, but delaying the evaluation of `User`, like: + + +```python +# Note the User being delayed by passing it as a string +def get_full_name(user: strawberry.Parent["User"]) -> str: + return f"{user.first_name} {user.last_name}" + + +@strawberry.type +class User: + first_name: str + last_name: str + full_name: str = strawberry.field(resolver=get_full_name) + + +@strawberry.type +class Query: + @strawberry.field + def user(self) -> User: + return User(first_name="John", last_name="Doe") + + +schema = strawberry.Schema(query=Query) +``` diff --git a/strawberry/parent.py b/strawberry/parent.py index 99223028ba..d43cb8ffc6 100644 --- a/strawberry/parent.py +++ b/strawberry/parent.py @@ -1,4 +1,7 @@ -from typing import Annotated, TypeVar +import re +from typing import Annotated, Any, ForwardRef, TypeVar + +_parent_re = re.compile(r"^(?:strawberry\.)?Parent\[(.*)\]$") class StrawberryParent: ... @@ -40,4 +43,20 @@ def user(self) -> User: ``` """ -__all__ = ["Parent"] + +def resolve_parent_forward_arg(annotation: Any) -> Any: + if isinstance(annotation, str): + str_annotation = annotation + elif isinstance(annotation, ForwardRef): + str_annotation = annotation.__forward_arg__ + else: + # If neither, return the annotation as is + return annotation + + if parent_match := _parent_re.match(str_annotation): + annotation = Parent[ForwardRef(parent_match.group(1))] # type: ignore[misc] + + return annotation + + +__all__ = ["Parent", "StrawberryParent", "resolve_parent_forward_arg"] diff --git a/strawberry/types/fields/resolver.py b/strawberry/types/fields/resolver.py index 45e6ad710d..63d27d8502 100644 --- a/strawberry/types/fields/resolver.py +++ b/strawberry/types/fields/resolver.py @@ -25,7 +25,7 @@ ConflictingArgumentsError, MissingArgumentsAnnotationsError, ) -from strawberry.parent import StrawberryParent +from strawberry.parent import StrawberryParent, resolve_parent_forward_arg from strawberry.types.arguments import StrawberryArgument from strawberry.types.base import StrawberryType, has_object_definition from strawberry.types.info import Info @@ -118,10 +118,17 @@ def find( try: evaled_annotation = annotation.evaluate() except NameError: - continue - else: - if self.is_reserved_type(evaled_annotation): - type_parameters.append(parameter) + # If this is a strawberry.Parent using ForwardRef, we will fail to + # evaluate at this moment, but at least knowing that it is a reserved + # type is enough for now + # We might want to revisit this in the future, maybe by postponing + # this check to when the schema is actually being created + evaled_annotation = resolve_parent_forward_arg( + annotation.annotation + ) + + if self.is_reserved_type(evaled_annotation): + type_parameters.append(parameter) if len(type_parameters) > 1: raise ConflictingArgumentsError( diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index 1db9ef9dd0..54a7efaf69 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -304,11 +304,13 @@ def eval_type( localns: Optional[dict] = None, ) -> type: """Evaluates a type, resolving forward references.""" + from strawberry.parent import StrawberryParent from strawberry.types.auto import StrawberryAuto from strawberry.types.lazy_type import StrawberryLazyReference from strawberry.types.private import StrawberryPrivate globalns = globalns or {} + # If this is not a string, maybe its args are (e.g. list["Foo"]) if isinstance(type_, ForwardRef): ast_obj = cast("ast.Expr", ast.parse(type_.__forward_arg__).body[0]) @@ -355,6 +357,7 @@ def eval_type( ) args = (type_arg, *remaining_args) break + if isinstance(arg, StrawberryAuto): remaining_args = [ a for a in args[1:] if not isinstance(a, StrawberryAuto) @@ -362,6 +365,21 @@ def eval_type( args = (args[0], arg, *remaining_args) break + if isinstance(arg, StrawberryParent): + remaining_args = [ + a for a in args[1:] if not isinstance(a, StrawberryParent) + ] + try: + type_arg = ( + eval_type(args[0], globalns, localns) + if isinstance(args[0], ForwardRef) + else args[0] + ) + except (NameError, TypeError): + type_arg = args[0] + args = (type_arg, arg, *remaining_args) + break + # If we have only a StrawberryLazyReference and no more annotations, # we need to return the argument directly because Annotated # will raise an error if trying to instantiate it with only diff --git a/tests/types/test_parent_type.py b/tests/types/test_parent_type.py new file mode 100644 index 0000000000..7b2cf2128a --- /dev/null +++ b/tests/types/test_parent_type.py @@ -0,0 +1,70 @@ +import textwrap +from typing import ForwardRef + +import pytest + +import strawberry +from strawberry.parent import resolve_parent_forward_arg + + +def test_parent_type(): + global User + + try: + + def get_full_name(user: strawberry.Parent["User"]) -> str: + return f"{user.first_name} {user.last_name}" + + @strawberry.type + class User: + first_name: str + last_name: str + full_name: str = strawberry.field(resolver=get_full_name) + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(first_name="John", last_name="Doe") # noqa: F821 + + schema = strawberry.Schema(query=Query) + + expected = """\ + type Query { + user: User! + } + + type User { + firstName: String! + lastName: String! + fullName: String! + } + """ + assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip() + + query = "{ user { firstName, lastName, fullName } }" + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "user": { + "firstName": "John", + "lastName": "Doe", + "fullName": "John Doe", + } + } + finally: + del User + + +@pytest.mark.parametrize( + ("annotation", "expected"), + [ + ("strawberry.Parent[str]", strawberry.Parent["str"]), + ("Parent[str]", strawberry.Parent["str"]), + (ForwardRef("strawberry.Parent[str]"), strawberry.Parent["str"]), + (ForwardRef("Parent[str]"), strawberry.Parent["str"]), + (strawberry.Parent["User"], strawberry.Parent["User"]), + ], +) +def test_resolve_parent_forward_arg(annotation, expected): + assert resolve_parent_forward_arg(annotation) == expected diff --git a/tests/types/test_parent_type_future_annotations.py b/tests/types/test_parent_type_future_annotations.py new file mode 100644 index 0000000000..d7bd3c62bb --- /dev/null +++ b/tests/types/test_parent_type_future_annotations.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import textwrap + +import strawberry + + +def test_parent_type(): + global User + + try: + + def get_full_name(user: strawberry.Parent[User]) -> str: + return f"{user.first_name} {user.last_name}" + + @strawberry.type + class User: + first_name: str + last_name: str + full_name: str = strawberry.field(resolver=get_full_name) + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(first_name="John", last_name="Doe") # noqa: F821 + + schema = strawberry.Schema(query=Query) + + expected = """\ + type Query { + user: User! + } + + type User { + firstName: String! + lastName: String! + fullName: String! + } + """ + assert textwrap.dedent(str(schema)).strip() == textwrap.dedent(expected).strip() + + query = "{ user { firstName, lastName, fullName } }" + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "user": { + "firstName": "John", + "lastName": "Doe", + "fullName": "John Doe", + } + } + finally: + del User