Skip to content

Commit f930fba

Browse files
committed
Raise an exception for typed dict errors rather than logging them immediately.
This lets us filter out bindings that aren't visible. I also made a couple of tangentially related minor improvements: * Changed the typed dict error message to say that keys should be "constant strings", not just "strings", since the error is also reported when the key is a string that can't be resolved to a pyval. * Silenced a weird lint error. Resolves #1186. PiperOrigin-RevId: 447085010
1 parent 91e5aa6 commit f930fba

File tree

5 files changed

+42
-20
lines changed

5 files changed

+42
-20
lines changed

pytype/abstract/function.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -818,8 +818,7 @@ def call_function(ctx,
818818
if (nodes and not ctx.options.strict_parameter_checks) or not error:
819819
return node, result
820820
elif fallback_to_unsolvable:
821-
if not isinstance(error, DictKeyMissing):
822-
ctx.errorlog.invalid_function_call(ctx.vm.stack(func_var.data[0]), error)
821+
ctx.errorlog.invalid_function_call(ctx.vm.stack(func_var.data[0]), error)
823822
return node, result
824823
else:
825824
# We were called by something that does its own error handling.

pytype/errors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,11 @@ def invalid_function_call(self, stack, error):
963963
stack, error.name, error.bad_call, error.duplicate)
964964
elif isinstance(error, function.UndefinedParameterError):
965965
self.name_error(stack, error.name)
966+
elif isinstance(error, typed_dict_overlay.TypedDictKeyMissing):
967+
self.typed_dict_error(stack, error.typed_dict, error.name)
968+
elif isinstance(error, function.DictKeyMissing):
969+
# We don't report DictKeyMissing because the false positive rate is high.
970+
pass
966971
else:
967972
raise AssertionError(error)
968973

@@ -1346,7 +1351,8 @@ def typed_dict_error(self, stack, obj, name):
13461351
if name:
13471352
err_msg = f"TypedDict {obj.class_name} does not contain key {name}"
13481353
else:
1349-
err_msg = f"TypedDict {obj.class_name} requires all keys to be strings"
1354+
err_msg = (f"TypedDict {obj.class_name} requires all keys to be constant "
1355+
"strings")
13501356
self.error(stack, err_msg)
13511357

13521358
@_error_name("final-error")

pytype/overlays/typed_dict.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44

5-
from typing import Any, Dict, Set
5+
from typing import Any, Dict, Optional, Set
66

77
from pytype.abstract import abstract
88
from pytype.abstract import abstract_utils
@@ -11,6 +11,13 @@
1111
from pytype.pytd import pytd
1212

1313

14+
class TypedDictKeyMissing(function.DictKeyMissing):
15+
16+
def __init__(self, typed_dict: "TypedDict", key: Optional[str]):
17+
super().__init__(key)
18+
self.typed_dict = typed_dict
19+
20+
1421
@dataclasses.dataclass
1522
class TypedDictProperties:
1623
"""Collection of typed dict properties passed between various stages."""
@@ -29,7 +36,7 @@ def optional(self):
2936
return self.keys - self.required
3037

3138
def add(self, k, v, total):
32-
self.fields[k] = v
39+
self.fields[k] = v # pylint: disable=unsupported-assignment-operation
3340
if total:
3441
self.required.add(k)
3542

@@ -220,13 +227,10 @@ def __repr__(self):
220227

221228
def _check_str_key(self, name):
222229
if name not in self.fields:
223-
self.ctx.errorlog.typed_dict_error(self.ctx.vm.frames, self, name)
224-
return False
225-
return True
230+
raise TypedDictKeyMissing(self, name)
226231

227232
def _check_str_key_value(self, node, name, value_var):
228-
if not self._check_str_key(name):
229-
return
233+
self._check_str_key(name)
230234
typ = abstract_utils.get_atomic_value(self.fields[name])
231235
bad, _ = self.ctx.matcher(node).bad_matches(value_var, typ)
232236
for view, error_details in bad:
@@ -240,10 +244,9 @@ def _check_key(self, name_var):
240244
"""Check that key is in the typed dict."""
241245
try:
242246
name = abstract_utils.get_atomic_python_constant(name_var, str)
243-
except abstract_utils.ConversionError:
244-
self.ctx.errorlog.typed_dict_error(self.ctx.vm.frames, self, name=None)
245-
return False
246-
return self._check_str_key(name)
247+
except abstract_utils.ConversionError as e:
248+
raise TypedDictKeyMissing(self, None) from e
249+
self._check_str_key(name)
247250

248251
def _check_value(self, node, name_var, value_var):
249252
"""Check that value has the right type."""
@@ -254,13 +257,12 @@ def _check_value(self, node, name_var, value_var):
254257
def getitem_slot(self, node, name_var):
255258
# A typed dict getitem should have a concrete string arg. If we have a var
256259
# with multiple bindings just fall back to Any.
257-
if not self._check_key(name_var):
258-
return node, self.ctx.new_unsolvable(node)
260+
self._check_key(name_var)
259261
return super().getitem_slot(node, name_var)
260262

261263
def setitem_slot(self, node, name_var, value_var):
262-
if self._check_key(name_var):
263-
self._check_value(node, name_var, value_var)
264+
self._check_key(name_var)
265+
self._check_value(node, name_var, value_var)
264266
return super().setitem_slot(node, name_var, value_var)
265267

266268
def set_str_item(self, node, name, value_var):

pytype/tests/test_typed_dict.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ def f(x: Foo):
297297
f(x)
298298
""")
299299

300+
def test_key_existence_check(self):
301+
self.Check("""
302+
from typing import Union
303+
from typing_extensions import TypedDict
304+
305+
class Foo(TypedDict):
306+
a: int
307+
class Bar(TypedDict):
308+
b: str
309+
class Baz(TypedDict):
310+
c: Union[Foo, Bar]
311+
312+
baz: Baz = {'c': {'a': 0}}
313+
assert 'a' in baz['c']
314+
print(baz['c']['a'])
315+
""")
316+
300317

301318
_SINGLE = """
302319
from typing import TypedDict

pytype/vm_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,6 @@ def call_binary_operator(state, name, x, y, report_errors, ctx):
736736
if ctx.options.report_errors:
737737
ctx.errorlog.unsupported_operands(ctx.vm.frames, name, x, y)
738738
result = ctx.new_unsolvable(state.node)
739-
elif isinstance(error, function.DictKeyMissing):
740-
state, result = error.get_return(state)
741739
else:
742740
if ctx.options.report_errors:
743741
ctx.errorlog.invalid_function_call(ctx.vm.frames, error)

0 commit comments

Comments
 (0)