Skip to content

Commit 5e4d097

Browse files
authored
Fix strict optional handling in dataclasses (#15571)
There were few cases when someone forgot to call `strict_optional_set()` in dataclasses plugin, let's move the calls directly to two places where they are needed for typeops. This may cause a tiny perf regression, but is much more robust in terms of preventing bugs.
1 parent 2e9c9b4 commit 5e4d097

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

mypy/plugins/dataclasses.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(
104104
info: TypeInfo,
105105
kw_only: bool,
106106
is_neither_frozen_nor_nonfrozen: bool,
107+
api: SemanticAnalyzerPluginInterface,
107108
) -> None:
108109
self.name = name
109110
self.alias = alias
@@ -116,6 +117,7 @@ def __init__(
116117
self.info = info
117118
self.kw_only = kw_only
118119
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
120+
self._api = api
119121

120122
def to_argument(self, current_info: TypeInfo) -> Argument:
121123
arg_kind = ARG_POS
@@ -138,7 +140,10 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
138140
# however this plugin is called very late, so all types should be fully ready.
139141
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
140142
# we serialize attributes).
141-
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
143+
with state.strict_optional_set(self._api.options.strict_optional):
144+
return expand_type(
145+
self.type, {self.info.self_type.id: fill_typevars(current_info)}
146+
)
142147
return self.type
143148

144149
def to_var(self, current_info: TypeInfo) -> Var:
@@ -165,13 +170,14 @@ def deserialize(
165170
) -> DataclassAttribute:
166171
data = data.copy()
167172
typ = deserialize_and_fixup_type(data.pop("type"), api)
168-
return cls(type=typ, info=info, **data)
173+
return cls(type=typ, info=info, **data, api=api)
169174

170175
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
171176
"""Expands type vars in the context of a subtype when an attribute is inherited
172177
from a generic super type."""
173178
if self.type is not None:
174-
self.type = map_type_from_supertype(self.type, sub_type, self.info)
179+
with state.strict_optional_set(self._api.options.strict_optional):
180+
self.type = map_type_from_supertype(self.type, sub_type, self.info)
175181

176182

177183
class DataclassTransformer:
@@ -230,12 +236,11 @@ def transform(self) -> bool:
230236
and ("__init__" not in info.names or info.names["__init__"].plugin_generated)
231237
and attributes
232238
):
233-
with state.strict_optional_set(self._api.options.strict_optional):
234-
args = [
235-
attr.to_argument(info)
236-
for attr in attributes
237-
if attr.is_in_init and not self._is_kw_only_type(attr.type)
238-
]
239+
args = [
240+
attr.to_argument(info)
241+
for attr in attributes
242+
if attr.is_in_init and not self._is_kw_only_type(attr.type)
243+
]
239244

240245
if info.fallback_to_any:
241246
# Make positional args optional since we don't know their order.
@@ -355,8 +360,7 @@ def transform(self) -> bool:
355360
self._add_dataclass_fields_magic_attribute()
356361

357362
if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
358-
with state.strict_optional_set(self._api.options.strict_optional):
359-
self._add_internal_replace_method(attributes)
363+
self._add_internal_replace_method(attributes)
360364
if "__post_init__" in info.names:
361365
self._add_internal_post_init_method(attributes)
362366

@@ -546,8 +550,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
546550
# TODO: We shouldn't be performing type operations during the main
547551
# semantic analysis pass, since some TypeInfo attributes might
548552
# still be in flux. This should be performed in a later phase.
549-
with state.strict_optional_set(self._api.options.strict_optional):
550-
attr.expand_typevar_from_subtype(cls.info)
553+
attr.expand_typevar_from_subtype(cls.info)
551554
found_attrs[name] = attr
552555

553556
sym_node = cls.info.names.get(name)
@@ -693,6 +696,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
693696
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
694697
cls.info
695698
),
699+
api=self._api,
696700
)
697701

698702
all_attrs = list(found_attrs.values())

test-data/unit/pythoneval.test

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,7 +2094,6 @@ grouped = groupby(pairs, key=fst)
20942094
[out]
20952095

20962096
[case testDataclassReplaceOptional]
2097-
# flags: --strict-optional
20982097
from dataclasses import dataclass, replace
20992098
from typing import Optional
21002099

@@ -2107,5 +2106,18 @@ reveal_type(a)
21072106
a2 = replace(a, x=None) # OK
21082107
reveal_type(a2)
21092108
[out]
2110-
_testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A"
2111-
_testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A"
2109+
_testDataclassReplaceOptional.py:9: note: Revealed type is "_testDataclassReplaceOptional.A"
2110+
_testDataclassReplaceOptional.py:11: note: Revealed type is "_testDataclassReplaceOptional.A"
2111+
2112+
[case testDataclassStrictOptionalAlwaysSet]
2113+
from dataclasses import dataclass
2114+
from typing import Callable, Optional
2115+
2116+
@dataclass
2117+
class Description:
2118+
name_fn: Callable[[Optional[int]], Optional[str]]
2119+
2120+
def f(d: Description) -> None:
2121+
reveal_type(d.name_fn)
2122+
[out]
2123+
_testDataclassStrictOptionalAlwaysSet.py:9: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"

0 commit comments

Comments
 (0)