Skip to content

Commit 3f35ebb

Browse files
authored
Add support for converters with TypeVars on generic attrs classes (#14908)
When creating generic classes using `attrs`, converters with type vars are not properly integrated into the generated `__init__`: ```python from typing import TypeVar, Generic, List, Iterable, Iterator import attr T = TypeVar('T') def int_gen() -> Iterator[int]: yield 1 def list_converter(x: Iterable[T]) -> List[T]: return list(x) @attr.s(auto_attribs=True) class A(Generic[T]): x: List[T] = attr.ib(converter=list_converter) y: T = attr.ib() a1 = A([1], 2) # E: Argument 1 to "A" has incompatible type "Iterator[int]"; expected "Iterable[T]" ``` This MR fixes the bug by copying type vars from the field/attrib into the type extracted from the converter.
1 parent 11c2c6d commit 3f35ebb

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

mypy/plugins/attrs.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from typing_extensions import Final, Literal
77

88
import mypy.plugin # To avoid circular imports.
9+
from mypy.applytype import apply_generic_arguments
910
from mypy.checker import TypeChecker
1011
from mypy.errorcodes import LITERAL_REQ
12+
from mypy.expandtype import expand_type
1113
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
1214
from mypy.messages import format_type_bare
1315
from mypy.nodes import (
@@ -23,6 +25,7 @@
2325
Decorator,
2426
Expression,
2527
FuncDef,
28+
IndexExpr,
2629
JsonDict,
2730
LambdaExpr,
2831
ListExpr,
@@ -34,6 +37,7 @@
3437
SymbolTableNode,
3538
TempNode,
3639
TupleExpr,
40+
TypeApplication,
3741
TypeInfo,
3842
TypeVarExpr,
3943
Var,
@@ -49,7 +53,7 @@
4953
deserialize_and_fixup_type,
5054
)
5155
from mypy.server.trigger import make_wildcard_trigger
52-
from mypy.typeops import make_simplified_union, map_type_from_supertype
56+
from mypy.typeops import get_type_vars, make_simplified_union, map_type_from_supertype
5357
from mypy.types import (
5458
AnyType,
5559
CallableType,
@@ -85,8 +89,9 @@
8589
class Converter:
8690
"""Holds information about a `converter=` argument"""
8791

88-
def __init__(self, init_type: Type | None = None) -> None:
92+
def __init__(self, init_type: Type | None = None, ret_type: Type | None = None) -> None:
8993
self.init_type = init_type
94+
self.ret_type = ret_type
9095

9196

9297
class Attribute:
@@ -115,11 +120,20 @@ def __init__(
115120
def argument(self, ctx: mypy.plugin.ClassDefContext) -> Argument:
116121
"""Return this attribute as an argument to __init__."""
117122
assert self.init
118-
119123
init_type: Type | None = None
120124
if self.converter:
121125
if self.converter.init_type:
122126
init_type = self.converter.init_type
127+
if init_type and self.init_type and self.converter.ret_type:
128+
# The converter return type should be the same type as the attribute type.
129+
# Copy type vars from attr type to converter.
130+
converter_vars = get_type_vars(self.converter.ret_type)
131+
init_vars = get_type_vars(self.init_type)
132+
if converter_vars and len(converter_vars) == len(init_vars):
133+
variables = {
134+
binder.id: arg for binder, arg in zip(converter_vars, init_vars)
135+
}
136+
init_type = expand_type(init_type, variables)
123137
else:
124138
ctx.api.fail("Cannot determine __init__ type from converter", self.context)
125139
init_type = AnyType(TypeOfAny.from_error)
@@ -653,6 +667,26 @@ def _parse_converter(
653667
from mypy.checkmember import type_object_type # To avoid import cycle.
654668

655669
converter_type = type_object_type(converter_expr.node, ctx.api.named_type)
670+
elif (
671+
isinstance(converter_expr, IndexExpr)
672+
and isinstance(converter_expr.analyzed, TypeApplication)
673+
and isinstance(converter_expr.base, RefExpr)
674+
and isinstance(converter_expr.base.node, TypeInfo)
675+
):
676+
# The converter is a generic type.
677+
from mypy.checkmember import type_object_type # To avoid import cycle.
678+
679+
converter_type = type_object_type(converter_expr.base.node, ctx.api.named_type)
680+
if isinstance(converter_type, CallableType):
681+
converter_type = apply_generic_arguments(
682+
converter_type,
683+
converter_expr.analyzed.types,
684+
ctx.api.msg.incompatible_typevar_value,
685+
converter_type,
686+
)
687+
else:
688+
converter_type = None
689+
656690
if isinstance(converter_expr, LambdaExpr):
657691
# TODO: should we send a fail if converter_expr.min_args > 1?
658692
converter_info.init_type = AnyType(TypeOfAny.unannotated)
@@ -671,6 +705,8 @@ def _parse_converter(
671705
converter_type = get_proper_type(converter_type)
672706
if isinstance(converter_type, CallableType) and converter_type.arg_types:
673707
converter_info.init_type = converter_type.arg_types[0]
708+
if not is_attr_converters_optional:
709+
converter_info.ret_type = converter_type.ret_type
674710
elif isinstance(converter_type, Overloaded):
675711
types: list[Type] = []
676712
for item in converter_type.items:

test-data/unit/check-attr.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,56 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"
469469

470470
[builtins fixtures/list.pyi]
471471

472+
[case testAttrsGenericWithConverter]
473+
from typing import TypeVar, Generic, List, Iterable, Iterator, Callable
474+
import attr
475+
T = TypeVar('T')
476+
477+
def int_gen() -> Iterator[int]:
478+
yield 1
479+
480+
def list_converter(x: Iterable[T]) -> List[T]:
481+
return list(x)
482+
483+
@attr.s(auto_attribs=True)
484+
class A(Generic[T]):
485+
x: List[T] = attr.ib(converter=list_converter)
486+
y: T = attr.ib()
487+
def foo(self) -> List[T]:
488+
return [self.y]
489+
def bar(self) -> T:
490+
return self.x[0]
491+
def problem(self) -> T:
492+
return self.x # E: Incompatible return value type (got "List[T]", expected "T")
493+
reveal_type(A) # N: Revealed type is "def [T] (x: typing.Iterable[T`1], y: T`1) -> __main__.A[T`1]"
494+
a1 = A([1], 2)
495+
reveal_type(a1) # N: Revealed type is "__main__.A[builtins.int]"
496+
reveal_type(a1.x) # N: Revealed type is "builtins.list[builtins.int]"
497+
reveal_type(a1.y) # N: Revealed type is "builtins.int"
498+
499+
a2 = A(int_gen(), 2)
500+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
501+
reveal_type(a2.x) # N: Revealed type is "builtins.list[builtins.int]"
502+
reveal_type(a2.y) # N: Revealed type is "builtins.int"
503+
504+
505+
def get_int() -> int:
506+
return 1
507+
508+
class Other(Generic[T]):
509+
def __init__(self, x: T) -> None:
510+
pass
511+
512+
@attr.s(auto_attribs=True)
513+
class B(Generic[T]):
514+
x: Other[Callable[..., T]] = attr.ib(converter=Other[Callable[..., T]])
515+
516+
b1 = B(get_int)
517+
reveal_type(b1) # N: Revealed type is "__main__.B[builtins.int]"
518+
reveal_type(b1.x) # N: Revealed type is "__main__.Other[def (*Any, **Any) -> builtins.int]"
519+
520+
[builtins fixtures/list.pyi]
521+
472522

473523
[case testAttrsUntypedGenericInheritance]
474524
from typing import Generic, TypeVar

0 commit comments

Comments
 (0)