Skip to content

Commit 0bf5418

Browse files
committed
Allow method return types to be changed when the overridden type is NoReturn.
PiperOrigin-RevId: 447558672
1 parent fb3d35c commit 0bf5418

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

pytype/overriding_checks.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,10 +327,10 @@ def _check_default_values(method_signature, base_signature):
327327
base_default = abstract_utils.get_atomic_value(base_default_value)
328328
method_default = abstract_utils.get_atomic_value(method_default_value)
329329

330-
# Unsolvable or Unknown matches anything.
331-
if isinstance(base_default, (abstract.Unsolvable, abstract.Unknown)):
330+
# Unsolvable, Unknown, or Empty matches anything.
331+
if isinstance(base_default, abstract.AMBIGUOUS_OR_EMPTY):
332332
continue
333-
if isinstance(method_default, (abstract.Unsolvable, abstract.Unknown)):
333+
if isinstance(method_default, abstract.AMBIGUOUS_OR_EMPTY):
334334
continue
335335

336336
if base_default != method_default:
@@ -351,6 +351,10 @@ def _check_return_types(method_signature, base_signature, is_subtype):
351351
# Return type not annotated in either of the two methods.
352352
return None
353353

354+
if (isinstance(base_return_type, abstract.AMBIGUOUS_OR_EMPTY) or
355+
isinstance(method_return_type, abstract.AMBIGUOUS_OR_EMPTY)):
356+
return None
357+
354358
# Return type of the overriding method must be a subtype of the
355359
# return type of the overridden method.
356360
if not is_subtype(method_return_type, base_return_type):

pytype/tests/test_overriding.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,19 @@ def f(self) -> str: # signature-mismatch
198198
return ''
199199
""")
200200

201+
def test_return_type_matches_empty(self):
202+
with self.DepTree([("foo.py", """
203+
class Foo:
204+
def f(self):
205+
raise NotImplementedError()
206+
""")]):
207+
self.Check("""
208+
import foo
209+
class Bar(foo.Foo):
210+
def f(self) -> None:
211+
pass
212+
""")
213+
201214
def test_pytdclass_signature_match(self):
202215
self.Check("""
203216
class Foo(list):

0 commit comments

Comments
 (0)