Skip to content

Commit 679b186

Browse files
authored
Merge pull request #1201 from google/google_sync
Google sync
2 parents 8fdc7ef + f930fba commit 679b186

File tree

8 files changed

+67
-23
lines changed

8 files changed

+67
-23
lines changed

pytype/abstract/function.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -818,8 +818,7 @@ def call_function(ctx,
818818
if (nodes and not ctx.options.strict_parameter_checks) or not error:
819819
return node, result
820820
elif fallback_to_unsolvable:
821-
if not isinstance(error, DictKeyMissing):
822-
ctx.errorlog.invalid_function_call(ctx.vm.stack(func_var.data[0]), error)
821+
ctx.errorlog.invalid_function_call(ctx.vm.stack(func_var.data[0]), error)
823822
return node, result
824823
else:
825824
# We were called by something that does its own error handling.

pytype/errors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,11 @@ def invalid_function_call(self, stack, error):
963963
stack, error.name, error.bad_call, error.duplicate)
964964
elif isinstance(error, function.UndefinedParameterError):
965965
self.name_error(stack, error.name)
966+
elif isinstance(error, typed_dict_overlay.TypedDictKeyMissing):
967+
self.typed_dict_error(stack, error.typed_dict, error.name)
968+
elif isinstance(error, function.DictKeyMissing):
969+
# We don't report DictKeyMissing because the false positive rate is high.
970+
pass
966971
else:
967972
raise AssertionError(error)
968973

@@ -1346,7 +1351,8 @@ def typed_dict_error(self, stack, obj, name):
13461351
if name:
13471352
err_msg = f"TypedDict {obj.class_name} does not contain key {name}"
13481353
else:
1349-
err_msg = f"TypedDict {obj.class_name} requires all keys to be strings"
1354+
err_msg = (f"TypedDict {obj.class_name} requires all keys to be constant "
1355+
"strings")
13501356
self.error(stack, err_msg)
13511357

13521358
@_error_name("final-error")

pytype/overlays/collections_overlay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ class ABCOverlay(typing_overlay.Redirect):
3434
"""A custom overlay for the 'collections.abc' module."""
3535

3636
def __init__(self, ctx):
37-
super().__init__("collections.abc", {}, ctx)
37+
super().__init__("collections.abc", {"Set": "typing.AbstractSet"}, ctx)

pytype/overlays/typed_dict.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44

5-
from typing import Any, Dict, Set
5+
from typing import Any, Dict, Optional, Set
66

77
from pytype.abstract import abstract
88
from pytype.abstract import abstract_utils
@@ -11,6 +11,13 @@
1111
from pytype.pytd import pytd
1212

1313

14+
class TypedDictKeyMissing(function.DictKeyMissing):
15+
16+
def __init__(self, typed_dict: "TypedDict", key: Optional[str]):
17+
super().__init__(key)
18+
self.typed_dict = typed_dict
19+
20+
1421
@dataclasses.dataclass
1522
class TypedDictProperties:
1623
"""Collection of typed dict properties passed between various stages."""
@@ -29,7 +36,7 @@ def optional(self):
2936
return self.keys - self.required
3037

3138
def add(self, k, v, total):
32-
self.fields[k] = v
39+
self.fields[k] = v # pylint: disable=unsupported-assignment-operation
3340
if total:
3441
self.required.add(k)
3542

@@ -220,13 +227,10 @@ def __repr__(self):
220227

221228
def _check_str_key(self, name):
222229
if name not in self.fields:
223-
self.ctx.errorlog.typed_dict_error(self.ctx.vm.frames, self, name)
224-
return False
225-
return True
230+
raise TypedDictKeyMissing(self, name)
226231

227232
def _check_str_key_value(self, node, name, value_var):
228-
if not self._check_str_key(name):
229-
return
233+
self._check_str_key(name)
230234
typ = abstract_utils.get_atomic_value(self.fields[name])
231235
bad, _ = self.ctx.matcher(node).bad_matches(value_var, typ)
232236
for view, error_details in bad:
@@ -240,10 +244,9 @@ def _check_key(self, name_var):
240244
"""Check that key is in the typed dict."""
241245
try:
242246
name = abstract_utils.get_atomic_python_constant(name_var, str)
243-
except abstract_utils.ConversionError:
244-
self.ctx.errorlog.typed_dict_error(self.ctx.vm.frames, self, name=None)
245-
return False
246-
return self._check_str_key(name)
247+
except abstract_utils.ConversionError as e:
248+
raise TypedDictKeyMissing(self, None) from e
249+
self._check_str_key(name)
247250

248251
def _check_value(self, node, name_var, value_var):
249252
"""Check that value has the right type."""
@@ -254,13 +257,12 @@ def _check_value(self, node, name_var, value_var):
254257
def getitem_slot(self, node, name_var):
255258
# A typed dict getitem should have a concrete string arg. If we have a var
256259
# with multiple bindings just fall back to Any.
257-
if not self._check_key(name_var):
258-
return node, self.ctx.new_unsolvable(node)
260+
self._check_key(name_var)
259261
return super().getitem_slot(node, name_var)
260262

261263
def setitem_slot(self, node, name_var, value_var):
262-
if self._check_key(name_var):
263-
self._check_value(node, name_var, value_var)
264+
self._check_key(name_var)
265+
self._check_value(node, name_var, value_var)
264266
return super().setitem_slot(node, name_var, value_var)
265267

266268
def set_str_item(self, node, name, value_var):

pytype/overlays/typing_overlay.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
Param = overlay_utils.Param
2424

2525

26+
def _is_typing_container(cls: pytd.Class):
27+
return pytd.IsContainer(cls) and cls.template
28+
29+
2630
class TypingOverlay(overlay.Overlay):
2731
"""A representation of the 'typing' module that allows custom overlays.
2832
@@ -39,7 +43,7 @@ def __init__(self, ctx):
3943
ast = ctx.loader.typing
4044
for cls in ast.classes:
4145
_, name = cls.name.rsplit(".", 1)
42-
if name not in member_map and pytd.IsContainer(cls) and cls.template:
46+
if name not in member_map and _is_typing_container(cls):
4347
member_map[name] = (overlay.build(name, TypingContainer), None)
4448
super().__init__(ctx, "typing", member_map, ast)
4549

@@ -70,7 +74,7 @@ def __init__(self, module_name, aliases, ctx):
7074
for pyval in ast.aliases + ast.classes + ast.constants + ast.functions:
7175
# Any public members that are not explicitly implemented are unsupported.
7276
_, name = pyval.name.rsplit(".", 1)
73-
if name.startswith("_"):
77+
if name.startswith("_") or name in member_map:
7478
continue
7579
if name in typing_overlay:
7680
member_map[name] = typing_overlay[name][0]
@@ -84,6 +88,9 @@ def __init__(self, module_name, aliases, ctx):
8488
def _build(name):
8589
def resolve(ctx):
8690
ast = ctx.loader.typing
91+
pytd_val = ast.Lookup(name)
92+
if isinstance(pytd_val, pytd.Class) and _is_typing_container(pytd_val):
93+
return TypingContainer(name.rsplit(".", 1)[-1], ctx)
8794
pytd_type = pytd.ToType(ast.Lookup(name), True, True, True)
8895
return ctx.convert.constant_to_value(pytd_type)
8996
return resolve

pytype/tests/test_collections_abc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,21 @@ def f() -> Callable[[], float]: ...
5353
x: float
5454
""")
5555

56+
def test_generator(self):
57+
self.Check("""
58+
from collections.abc import Generator
59+
def f() -> Generator[int, None, None]:
60+
yield 0
61+
""")
62+
63+
def test_set(self):
64+
# collections.abc.Set is an alias for typing.AbstractSet.
65+
self.Check("""
66+
from collections.abc import Set
67+
def f() -> Set[int]:
68+
return frozenset([0])
69+
""")
70+
5671

5772
if __name__ == "__main__":
5873
test_base.main()

pytype/tests/test_typed_dict.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ def f(x: Foo):
297297
f(x)
298298
""")
299299

300+
def test_key_existence_check(self):
301+
self.Check("""
302+
from typing import Union
303+
from typing_extensions import TypedDict
304+
305+
class Foo(TypedDict):
306+
a: int
307+
class Bar(TypedDict):
308+
b: str
309+
class Baz(TypedDict):
310+
c: Union[Foo, Bar]
311+
312+
baz: Baz = {'c': {'a': 0}}
313+
assert 'a' in baz['c']
314+
print(baz['c']['a'])
315+
""")
316+
300317

301318
_SINGLE = """
302319
from typing import TypedDict

pytype/vm_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,6 @@ def call_binary_operator(state, name, x, y, report_errors, ctx):
736736
if ctx.options.report_errors:
737737
ctx.errorlog.unsupported_operands(ctx.vm.frames, name, x, y)
738738
result = ctx.new_unsolvable(state.node)
739-
elif isinstance(error, function.DictKeyMissing):
740-
state, result = error.get_return(state)
741739
else:
742740
if ctx.options.report_errors:
743741
ctx.errorlog.invalid_function_call(ctx.vm.frames, error)

0 commit comments

Comments
 (0)