diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..f8bb92af65 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,8 @@ +Release type: minor + +## Deprecation Alert +- **Deprecated**: Replaced `_enum_definition` with `__strawberry_definition__`. The former will continue to work but will raise a deprecation warning. + +### Other Changes +- **Renamed**: Changed `EnumDefinition` to `StrawberryEnum` to standardize internal naming patterns. +- These updates improve naming consistency and address previously identified TODOs. diff --git a/strawberry/annotation.py b/strawberry/annotation.py index fea5d5cbb1..04ccffffe4 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -29,7 +29,7 @@ get_object_definition, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.enum import enum as strawberry_enum from strawberry.types.lazy_type import LazyType from strawberry.types.maybe import _annotation_is_maybe @@ -215,11 +215,11 @@ def create_concrete_type(self, evaled_type: type) -> type: return evaled_type.__strawberry_definition__.resolve_generic(evaled_type) raise ValueError(f"Not supported {evaled_type}") - def create_enum(self, evaled_type: Any) -> EnumDefinition: + def create_enum(self, evaled_type: Any) -> StrawberryEnum: try: - return evaled_type._enum_definition + return evaled_type.__strawberry_definition__ except AttributeError: - return strawberry_enum(evaled_type)._enum_definition + return strawberry_enum(evaled_type).__strawberry_definition__ def create_list(self, evaled_type: Any) -> StrawberryList: item_type, *_ = get_args(evaled_type) @@ -388,7 +388,7 @@ def _is_strawberry_type(cls, evaled_type: Any) -> bool: # Prevent import cycles from strawberry.types.union import StrawberryUnion - if isinstance(evaled_type, EnumDefinition): + if isinstance(evaled_type, StrawberryEnum): return True elif _is_input_type(evaled_type): # TODO: Replace with StrawberryInputObject return True diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 0c7226c91c..4985a63b96 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -42,7 +42,7 @@ get_object_definition, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.lazy_type import LazyType from strawberry.types.scalar import ScalarDefinition, ScalarWrapper from strawberry.types.union import StrawberryUnion @@ -543,7 +543,7 @@ def _get_field_type( if isinstance(field_type, ScalarDefinition): return self._collect_scalar(field_type, None) - if isinstance(field_type, EnumDefinition): + if isinstance(field_type, StrawberryEnum): return self._collect_enum(field_type) raise ValueError(f"Unsupported type: {field_type}") # pragma: no cover @@ -897,7 +897,7 @@ def _collect_scalar( return graphql_scalar - def _collect_enum(self, enum: EnumDefinition) -> GraphQLEnum: + def _collect_enum(self, enum: StrawberryEnum) -> GraphQLEnum: graphql_enum = GraphQLEnum( enum.name, [value.name for value in enum.values], diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index ed5cbd86ed..1fb7a20af3 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -9,7 +9,7 @@ StrawberryOptional, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.union import StrawberryUnion if TYPE_CHECKING: @@ -40,7 +40,7 @@ def _convert_from_pydantic_to_strawberry_type( return _convert_from_pydantic_to_strawberry_type( option_type, data_from_model=data, extra=extra ) - if isinstance(type_, EnumDefinition): + if isinstance(type_, StrawberryEnum): return data if isinstance(type_, StrawberryList): items = [] diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index 4bb2359fa3..99633b40cc 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -34,7 +34,7 @@ from strawberry.federation.schema_directives import ComposeDirective from strawberry.schema.config import StrawberryConfig from strawberry.schema_directive import StrawberrySchemaDirective - from strawberry.types.enum import EnumDefinition + from strawberry.types.enum import StrawberryEnum from strawberry.types.scalar import ScalarDefinition, ScalarWrapper @@ -372,7 +372,7 @@ def _has_federation_keys( definition: Union[ StrawberryObjectDefinition, "ScalarDefinition", - "EnumDefinition", + "StrawberryEnum", "StrawberryUnion", ], ) -> bool: diff --git a/strawberry/printer/printer.py b/strawberry/printer/printer.py index 74152b8cd8..387d8c21db 100644 --- a/strawberry/printer/printer.py +++ b/strawberry/printer/printer.py @@ -36,7 +36,7 @@ StrawberryObjectDefinition, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.scalar import ScalarWrapper from strawberry.types.unset import UNSET @@ -182,7 +182,7 @@ def print_schema_directive( if hasattr(f_type, "_scalar_definition"): extras.types.add(cast("type", f_type)) - if isinstance(f_type, EnumDefinition): + if isinstance(f_type, StrawberryEnum): extras.types.add(cast("type", f_type)) return f" @{gql_directive.name}{params}" diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index 89d0d6c18f..3c3934c156 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -22,7 +22,7 @@ StrawberryObjectDefinition, WithStrawberryObjectDefinition, ) - from strawberry.types.enum import EnumDefinition + from strawberry.types.enum import StrawberryEnum from strawberry.types.graphql import OperationType from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion @@ -82,7 +82,7 @@ def get_type_by_name( ) -> ( StrawberryObjectDefinition | ScalarDefinition - | EnumDefinition + | StrawberryEnum | StrawberryUnion | None ): diff --git a/strawberry/schema/compat.py b/strawberry/schema/compat.py index 257ebe06b6..d6f6194083 100644 --- a/strawberry/schema/compat.py +++ b/strawberry/schema/compat.py @@ -4,6 +4,7 @@ from strawberry.scalars import is_scalar as is_strawberry_scalar from strawberry.types.base import StrawberryType, has_object_definition +from strawberry.types.enum import StrawberryEnum # TypeGuard is only available in typing_extensions => 3.10, we don't want # to force updates to the typing_extensions package so we only use it when @@ -36,7 +37,10 @@ def is_scalar( def is_enum(type_: StrawberryType | type) -> TypeGuard[type]: - return hasattr(type_, "_enum_definition") + if hasattr(type_, "__strawberry_definition__"): + return isinstance(type_.__strawberry_definition__, StrawberryEnum) + + return False def is_schema_directive(type_: StrawberryType | type) -> TypeGuard[type]: diff --git a/strawberry/schema/name_converter.py b/strawberry/schema/name_converter.py index 599fc2bf90..12143b7bd4 100644 --- a/strawberry/schema/name_converter.py +++ b/strawberry/schema/name_converter.py @@ -11,7 +11,7 @@ StrawberryOptional, has_object_definition, ) -from strawberry.types.enum import EnumDefinition, EnumValue +from strawberry.types.enum import EnumValue, StrawberryEnum from strawberry.types.lazy_type import LazyType from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion @@ -45,7 +45,7 @@ def from_type( ) -> str: if isinstance(type_, (StrawberryDirective, StrawberrySchemaDirective)): return self.from_directive(type_) - if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum + if isinstance(type_, StrawberryEnum): return self.from_enum(type_) if isinstance(type_, StrawberryObjectDefinition): if type_.is_input: @@ -78,10 +78,10 @@ def from_input_object(self, input_type: StrawberryObjectDefinition) -> str: def from_interface(self, interface: StrawberryObjectDefinition) -> str: return self.from_object(interface) - def from_enum(self, enum: EnumDefinition) -> str: + def from_enum(self, enum: StrawberryEnum) -> str: return enum.name - def from_enum_value(self, enum: EnumDefinition, enum_value: EnumValue) -> str: + def from_enum_value(self, enum: StrawberryEnum, enum_value: EnumValue) -> str: return enum_value.name def from_directive( @@ -152,7 +152,7 @@ def get_name_from_type(self, type_: StrawberryType | type) -> str: if isinstance(type_, LazyType): type_ = type_.resolve_type() - if isinstance(type_, EnumDefinition): + if isinstance(type_, StrawberryEnum): name = type_.name elif isinstance(type_, StrawberryUnion): name = type_.graphql_name if type_.graphql_name else self.from_union(type_) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 48259566f4..2a3fcd292e 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -89,7 +89,7 @@ from strawberry.directive import StrawberryDirective from strawberry.types.base import StrawberryType - from strawberry.types.enum import EnumDefinition + from strawberry.types.enum import StrawberryEnum from strawberry.types.field import StrawberryField from strawberry.types.scalar import ScalarDefinition, ScalarWrapper from strawberry.types.union import StrawberryUnion @@ -438,7 +438,7 @@ def get_type_by_name( ) -> ( StrawberryObjectDefinition | ScalarDefinition - | EnumDefinition + | StrawberryEnum | StrawberryUnion | None ): @@ -913,7 +913,7 @@ def _resolve_node_ids(self) -> None: for concrete_type in self.schema_converter.type_map.values(): type_def = concrete_type.definition - # This can be a TypeDefinition, EnumDefinition, ScalarDefinition + # This can be a TypeDefinition, StrawberryEnum, ScalarDefinition # or UnionDefinition if not isinstance(type_def, StrawberryObjectDefinition): continue diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index c2697e0f9d..d83124c995 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -66,7 +66,7 @@ has_object_definition, ) from strawberry.types.cast import get_strawberry_type_cast -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.field import UNRESOLVED from strawberry.types.lazy_type import LazyType from strawberry.types.private import is_private @@ -156,7 +156,7 @@ def _get_thunk_mapping( class CustomGraphQLEnumType(GraphQLEnumType): def __init__( self, - enum: EnumDefinition, + enum: StrawberryEnum, *args: Any, **kwargs: Any, ) -> None: @@ -300,7 +300,7 @@ def from_argument(self, argument: StrawberryArgument) -> GraphQLArgument: }, ) - def from_enum(self, enum: EnumDefinition) -> CustomGraphQLEnumType: + def from_enum(self, enum: StrawberryEnum) -> CustomGraphQLEnumType: enum_name = self.config.name_converter.from_type(enum) assert enum_name is not None @@ -874,7 +874,7 @@ def from_type(self, type_: StrawberryType | type) -> GraphQLNullableType: if len(args) >= 2 and isinstance(args[1], StrawberryUnion): type_ = args[1] - if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum + if isinstance(type_, StrawberryEnum): return self.from_enum(type_) if compat.is_input_type(type_): # TODO: Replace with StrawberryInputObject return self.from_input_object(type_) @@ -887,8 +887,8 @@ def from_type(self, type_: StrawberryType | type) -> GraphQLNullableType: return self.from_interface(type_definition) if has_object_definition(type_): return self.from_object(type_.__strawberry_definition__) - if compat.is_enum(type_): # TODO: Replace with StrawberryEnum - enum_definition: EnumDefinition = type_._enum_definition # type: ignore + if compat.is_enum(type_): + enum_definition: StrawberryEnum = type_.__strawberry_definition__ # type: ignore return self.from_enum(enum_definition) if isinstance(type_, StrawberryObjectDefinition): return self.from_object(type_) @@ -1023,14 +1023,14 @@ def validate_same_type_definition( if isinstance(second_type_definition, StrawberryObjectDefinition): first_origin = second_type_definition.origin - elif isinstance(second_type_definition, EnumDefinition): + elif isinstance(second_type_definition, StrawberryEnum): first_origin = second_type_definition.wrapped_cls else: first_origin = None if isinstance(first_type_definition, StrawberryObjectDefinition): second_origin = first_type_definition.origin - elif isinstance(first_type_definition, EnumDefinition): + elif isinstance(first_type_definition, StrawberryEnum): second_origin = first_type_definition.wrapped_cls else: second_origin = None diff --git a/strawberry/schema/types/concrete_type.py b/strawberry/schema/types/concrete_type.py index 4c8d5857f4..3c5fd04d48 100644 --- a/strawberry/schema/types/concrete_type.py +++ b/strawberry/schema/types/concrete_type.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from strawberry.types.base import StrawberryObjectDefinition - from strawberry.types.enum import EnumDefinition + from strawberry.types.enum import StrawberryEnum from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion @@ -17,7 +17,7 @@ @dataclasses.dataclass class ConcreteType: definition: ( - StrawberryObjectDefinition | EnumDefinition | ScalarDefinition | StrawberryUnion + StrawberryObjectDefinition | StrawberryEnum | ScalarDefinition | StrawberryUnion ) implementation: GraphQLType diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index 06bc1b5d4e..1941629da5 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -20,7 +20,7 @@ StrawberryOptional, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.lazy_type import LazyType, StrawberryLazyReference from strawberry.types.maybe import Some from strawberry.types.unset import UNSET as _deprecated_UNSET # noqa: N811 @@ -154,16 +154,12 @@ def _is_leaf_type( if is_scalar(type_, scalar_registry): return True - if isinstance(type_, EnumDefinition): + if isinstance(type_, StrawberryEnum): return True if isinstance(type_, LazyType): return _is_leaf_type(type_.resolve_type(), scalar_registry) - if hasattr(type_, "_enum_definition"): - enum_definition: EnumDefinition = type_._enum_definition - return _is_leaf_type(enum_definition, scalar_registry) - return False @@ -241,8 +237,10 @@ def convert_argument( if isinstance(type_, LazyType): return convert_argument(value, type_.resolve_type(), scalar_registry, config) - if hasattr(type_, "_enum_definition"): - enum_definition: EnumDefinition = type_._enum_definition + if hasattr(type_, "__strawberry_definition__") and isinstance( + type_.__strawberry_definition__, StrawberryEnum + ): + enum_definition: StrawberryEnum = type_.__strawberry_definition__ return convert_argument(value, enum_definition, scalar_registry, config) if has_object_definition(type_): diff --git a/strawberry/types/base.py b/strawberry/types/base.py index d603b6d08c..aaf160c2ac 100644 --- a/strawberry/types/base.py +++ b/strawberry/types/base.py @@ -202,7 +202,7 @@ def has_object_definition( obj: Any, ) -> TypeGuard[type[WithStrawberryObjectDefinition]]: if hasattr(obj, "__strawberry_definition__"): - return True + return isinstance(obj.__strawberry_definition__, StrawberryObjectDefinition) # TODO: Generics remove dunder members here, so we inject it here. # Would be better to avoid it somehow. # https://github.com/python/cpython/blob/3a314f7c3df0dd7c37da7d12b827f169ee60e1ea/Lib/typing.py#L1152 @@ -210,7 +210,7 @@ def has_object_definition( concrete = obj.__origin__ if hasattr(concrete, "__strawberry_definition__"): obj.__strawberry_definition__ = concrete.__strawberry_definition__ - return True + return isinstance(obj.__strawberry_definition__, StrawberryObjectDefinition) return False @@ -419,13 +419,18 @@ def is_implemented_by(self, root: type[WithStrawberryObjectDefinition]) -> bool: continue # Check if the expected type matches the type found on the type_map - real_concrete_type = type(value) + from strawberry.types.enum import StrawberryEnum + + real_concrete_type: type | StrawberryEnum = type(value) # TODO: uniform type var map, at the moment we map object types # to their class (not to TypeDefinition) while we map enum to - # the EnumDefinition class. This is why we do this check here: - if hasattr(real_concrete_type, "_enum_definition"): - real_concrete_type = real_concrete_type._enum_definition + # the StrawberryEnum class. This is why we do this check here: + + if hasattr(real_concrete_type, "__strawberry_definition__") and isinstance( + real_concrete_type.__strawberry_definition__, StrawberryEnum + ): + real_concrete_type = real_concrete_type.__strawberry_definition__ if ( isinstance(expected_concrete_type, type) diff --git a/strawberry/types/enum.py b/strawberry/types/enum.py index 17f021b4ca..f433b4c7f1 100644 --- a/strawberry/types/enum.py +++ b/strawberry/types/enum.py @@ -1,14 +1,11 @@ import dataclasses from collections.abc import Callable, Iterable, Mapping from enum import EnumMeta -from typing import ( - Any, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, TypeVar, overload from strawberry.exceptions import ObjectIsNotAnEnumError from strawberry.types.base import StrawberryType +from strawberry.utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor @dataclasses.dataclass @@ -21,7 +18,7 @@ class EnumValue: @dataclasses.dataclass -class EnumDefinition(StrawberryType): +class StrawberryEnum(StrawberryType): wrapped_cls: EnumMeta name: str values: list[EnumValue] @@ -148,7 +145,7 @@ def _process_enum( ) values.append(value) - cls._enum_definition = EnumDefinition( # type: ignore + cls.__strawberry_definition__ = StrawberryEnum( # type: ignore wrapped_cls=cls, name=name, values=values, @@ -156,6 +153,13 @@ def _process_enum( directives=directives, ) + # TODO: remove when deprecating _enum_definition + DeprecatedDescriptor( + DEPRECATION_MESSAGES._ENUM_DEFINITION, + cls.__strawberry_definition__, # type: ignore[attr-defined] + "_enum_definition", + ).inject(cls) + return cls @@ -235,4 +239,22 @@ def wrap(cls: EnumType) -> EnumType: return wrap(cls) -__all__ = ["EnumDefinition", "EnumValue", "EnumValueDefinition", "enum", "enum_value"] +# TODO: remove when deprecating _enum_definition +if TYPE_CHECKING: + from typing_extensions import deprecated + + @deprecated("Use StrawberryEnum instead") + class EnumDefinition(StrawberryEnum): ... + +else: + EnumDefinition = StrawberryEnum + + +__all__ = [ + "EnumDefinition", + "EnumValue", + "EnumValueDefinition", + "StrawberryEnum", + "enum", + "enum_value", +] diff --git a/strawberry/utils/deprecations.py b/strawberry/utils/deprecations.py index eee08676e3..23c89e9c09 100644 --- a/strawberry/utils/deprecations.py +++ b/strawberry/utils/deprecations.py @@ -8,6 +8,9 @@ class DEPRECATION_MESSAGES: # noqa: N801 _TYPE_DEFINITION = ( "_type_definition is deprecated, use __strawberry_definition__ instead" ) + _ENUM_DEFINITION = ( + "_enum_definition is deprecated, use __strawberry_definition__ instead" + ) class DeprecatedDescriptor: diff --git a/tests/benchmarks/test_stadium.py b/tests/benchmarks/test_stadium.py index 8dfcf8e57b..3b97f2bdfd 100644 --- a/tests/benchmarks/test_stadium.py +++ b/tests/benchmarks/test_stadium.py @@ -76,10 +76,10 @@ def create_stadium(seats_per_row: int = 250) -> Stadium: """Create a stadium with a configurable number of seats per row. Default configuration (250 seats/row) creates approximately 50,000 seats: - - North Stand: 12,500 seats (50 rows × 250 seats) - - South Stand: 12,500 seats (50 rows × 250 seats) - - East Stand: 10,000 seats (40 rows × 250 seats) - - West Stand: 10,000 seats (40 rows × 250 seats) + - North Stand: 12,500 seats (50 rows x 250 seats) + - South Stand: 12,500 seats (50 rows x 250 seats) + - East Stand: 10,000 seats (40 rows x 250 seats) + - West Stand: 10,000 seats (40 rows x 250 seats) """ stands = [] diff --git a/tests/enums/test_enum.py b/tests/enums/test_enum.py index c0dea47813..93c56981c0 100644 --- a/tests/enums/test_enum.py +++ b/tests/enums/test_enum.py @@ -5,7 +5,7 @@ import strawberry from strawberry.exceptions import ObjectIsNotAnEnumError from strawberry.types.base import get_object_definition -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum def test_basic_enum(): @@ -15,7 +15,7 @@ class IceCreamFlavour(Enum): STRAWBERRY = "strawberry" CHOCOLATE = "chocolate" - definition = IceCreamFlavour._enum_definition + definition = IceCreamFlavour.__strawberry_definition__ assert definition.name == "IceCreamFlavour" assert definition.description is None @@ -37,7 +37,7 @@ class IceCreamFlavour(Enum): STRAWBERRY = "strawberry" CHOCOLATE = "chocolate" - definition = IceCreamFlavour._enum_definition + definition = IceCreamFlavour.__strawberry_definition__ assert definition.name == "Flavour" assert definition.description == "example" @@ -58,7 +58,7 @@ def flavour_available(self, flavour: IceCreamFlavour) -> bool: field = Query.__strawberry_definition__.fields[0] - assert isinstance(field.arguments[0].type, EnumDefinition) + assert isinstance(field.arguments[0].type, StrawberryEnum) @pytest.mark.raises_strawberry_exception( @@ -80,7 +80,7 @@ class IceCreamFlavour(Enum): ) CHOCOLATE = "chocolate" - definition = IceCreamFlavour._enum_definition + definition = IceCreamFlavour.__strawberry_definition__ assert definition.values[0].name == "VANILLA" assert definition.values[0].value == "vanilla" @@ -105,7 +105,7 @@ class IceCreamFlavour(Enum): ) CHOCOLATE = "chocolate" - definition = IceCreamFlavour._enum_definition + definition = IceCreamFlavour.__strawberry_definition__ assert definition.values[0].name == "VANILLA" assert definition.values[0].value == "vanilla" diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 3fe101a68b..c8e87a3755 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -13,7 +13,7 @@ StrawberryObjectDefinition, StrawberryOptional, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.union import StrawberryUnion @@ -575,7 +575,7 @@ class UserType: assert field1.type is int assert field2.python_name == "kind" - assert isinstance(field2.type, EnumDefinition) + assert isinstance(field2.type, StrawberryEnum) assert field2.type.wrapped_cls is UserKind diff --git a/tests/objects/generics/test_names.py b/tests/objects/generics/test_names.py index 54b3edb8db..a54520b72f 100644 --- a/tests/objects/generics/test_names.py +++ b/tests/objects/generics/test_names.py @@ -6,7 +6,7 @@ import strawberry from strawberry.schema.config import StrawberryConfig from strawberry.types.base import StrawberryList, StrawberryOptional -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.lazy_type import LazyType from strawberry.types.union import StrawberryUnion @@ -15,7 +15,7 @@ V = TypeVar("V") -Enum = EnumDefinition(None, name="Enum", values=[], description=None) # type: ignore +Enum = StrawberryEnum(None, name="Enum", values=[], description=None) # type: ignore CustomInt = strawberry.scalar(NewType("CustomInt", int)) diff --git a/tests/schema/test_extensions.py b/tests/schema/test_extensions.py index 8b56d58a8e..4c8b3f4500 100644 --- a/tests/schema/test_extensions.py +++ b/tests/schema/test_extensions.py @@ -85,15 +85,16 @@ class Query: graphql_thing_type = cast("GraphQLEnumType", graphql_schema.get_type("ThingType")) assert ( - graphql_thing_type.extensions[DEFINITION_BACKREF] is ThingType._enum_definition + graphql_thing_type.extensions[DEFINITION_BACKREF] + is ThingType.__strawberry_definition__ ) assert ( graphql_thing_type.values["JSON"].extensions[DEFINITION_BACKREF] - is ThingType._enum_definition.values[0] + is ThingType.__strawberry_definition__.values[0] ) assert ( graphql_thing_type.values["STR"].extensions[DEFINITION_BACKREF] - is ThingType._enum_definition.values[1] + is ThingType.__strawberry_definition__.values[1] ) diff --git a/tests/schema/test_name_converter.py b/tests/schema/test_name_converter.py index 4b13d994a7..611ea91d6d 100644 --- a/tests/schema/test_name_converter.py +++ b/tests/schema/test_name_converter.py @@ -9,7 +9,7 @@ from strawberry.schema_directive import Location, StrawberrySchemaDirective from strawberry.types.arguments import StrawberryArgument from strawberry.types.base import StrawberryObjectDefinition, StrawberryType -from strawberry.types.enum import EnumDefinition, EnumValue +from strawberry.types.enum import EnumValue, StrawberryEnum from strawberry.types.field import StrawberryField from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion @@ -53,10 +53,10 @@ def from_input_object(self, input_type: StrawberryObjectDefinition) -> str: def from_object(self, object_type: StrawberryObjectDefinition) -> str: return super().from_object(object_type) + self.suffix - def from_enum(self, enum: EnumDefinition) -> str: + def from_enum(self, enum: StrawberryEnum) -> str: return super().from_enum(enum) + self.suffix - def from_enum_value(self, enum: EnumDefinition, enum_value: EnumValue) -> str: + def from_enum_value(self, enum: StrawberryEnum, enum_value: EnumValue) -> str: return super().from_enum_value(enum, enum_value) + self.suffix diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 1d9ac3abf8..23d3ac1b40 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -1,3 +1,5 @@ +from enum import Enum + import pytest import strawberry @@ -9,10 +11,15 @@ class A: a: int +@strawberry.enum +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + def test_type_definition_is_aliased(): - with pytest.warns( - match="_type_definition is deprecated, use __strawberry_definition__ instead" - ): + with pytest.warns(match=DEPRECATION_MESSAGES._TYPE_DEFINITION): assert A.__strawberry_definition__ is A._type_definition @@ -22,6 +29,24 @@ def test_get_warns(): def test_can_import_type_definition(): - from strawberry.types.base import TypeDefinition + from strawberry.types.base import StrawberryObjectDefinition, TypeDefinition assert TypeDefinition + assert TypeDefinition is StrawberryObjectDefinition + + +def test_enum_definition_is_aliased(): + with pytest.warns(match=DEPRECATION_MESSAGES._ENUM_DEFINITION): + assert Color.__strawberry_definition__ is Color._enum_definition + + +def test_enum_get_warns(): + with pytest.warns(match=DEPRECATION_MESSAGES._ENUM_DEFINITION): + assert Color._enum_definition.name == "Color" + + +def test_can_import_enum_definition(): + from strawberry.types.enum import EnumDefinition, StrawberryEnum + + assert EnumDefinition + assert EnumDefinition is StrawberryEnum diff --git a/tests/types/resolving/test_enums.py b/tests/types/resolving/test_enums.py index 1c74b1b8c1..9ba0c948ef 100644 --- a/tests/types/resolving/test_enums.py +++ b/tests/types/resolving/test_enums.py @@ -16,4 +16,4 @@ class NumaNuma(Enum): resolved = annotation.resolve() # TODO: Remove reference to .enum_definition with StrawberryEnum - assert resolved is NumaNuma._enum_definition + assert resolved is NumaNuma.__strawberry_definition__ diff --git a/tests/types/resolving/test_generics.py b/tests/types/resolving/test_generics.py index bf5529c081..1ae0864ad2 100644 --- a/tests/types/resolving/test_generics.py +++ b/tests/types/resolving/test_generics.py @@ -13,7 +13,7 @@ get_object_definition, has_object_definition, ) -from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import StrawberryEnum from strawberry.types.field import StrawberryField from strawberry.types.union import StrawberryUnion @@ -114,8 +114,8 @@ class GenericForEnum(Generic[T]): assert isinstance(resolved.__strawberry_definition__, StrawberryObjectDefinition) generic_slot_field: StrawberryField = resolved.__strawberry_definition__.fields[0] - assert isinstance(generic_slot_field.type, EnumDefinition) - assert generic_slot_field.type is VehicleMake._enum_definition + assert isinstance(generic_slot_field.type, StrawberryEnum) + assert generic_slot_field.type is VehicleMake.__strawberry_definition__ def test_cant_create_concrete_of_non_strawberry_object(): diff --git a/tests/types/test_argument_types.py b/tests/types/test_argument_types.py index 74a3137931..ddf1086ee9 100644 --- a/tests/types/test_argument_types.py +++ b/tests/types/test_argument_types.py @@ -21,8 +21,8 @@ def set_locale(locale: Locale) -> bool: return True argument = set_locale.arguments[0] - # TODO: Remove reference to ._enum_definition with StrawberryEnum - assert argument.type is Locale._enum_definition + # TODO: Remove reference to .__strawberry_definition__ with StrawberryEnum + assert argument.type is Locale.__strawberry_definition__ def test_forward_reference(): diff --git a/tests/types/test_field_types.py b/tests/types/test_field_types.py index f50e9d16db..10302717ec 100644 --- a/tests/types/test_field_types.py +++ b/tests/types/test_field_types.py @@ -16,8 +16,8 @@ class Egnum(Enum): annotation = StrawberryAnnotation(Egnum) field = StrawberryField(type_annotation=annotation) - # TODO: Remove reference to ._enum_definition with StrawberryEnum - assert field.type is Egnum._enum_definition + # TODO: Remove reference to .__strawberry_definition__ with StrawberryEnum + assert field.type is Egnum.__strawberry_definition__ def test_forward_reference(): diff --git a/tests/types/test_object_types.py b/tests/types/test_object_types.py index 93385df7a4..c773783549 100644 --- a/tests/types/test_object_types.py +++ b/tests/types/test_object_types.py @@ -23,8 +23,8 @@ class Animal: field: StrawberryField = get_object_definition(Animal).fields[0] - # TODO: Remove reference to ._enum_definition with StrawberryEnum - assert field.type is Count._enum_definition + # TODO: Remove reference to .__strawberry_definition__ with StrawberryEnum + assert field.type is Count.__strawberry_definition__ def test_forward_reference(): diff --git a/tests/types/test_resolver_types.py b/tests/types/test_resolver_types.py index 2c7d8708f4..9f91101371 100644 --- a/tests/types/test_resolver_types.py +++ b/tests/types/test_resolver_types.py @@ -18,8 +18,8 @@ def get_spoken_language() -> Language: return Language.ENGLISH resolver = StrawberryResolver(get_spoken_language) - # TODO: Remove reference to ._enum_definition with StrawberryEnum - assert resolver.type is Language._enum_definition + # TODO: Remove reference to .__strawberry_definition__ with StrawberryEnum + assert resolver.type is Language.__strawberry_definition__ def test_forward_references():