Skip to content

Commit a2c5f97

Browse files
authored
Merge pull request #1199 from google/google_sync
Google sync
2 parents 5710b85 + 855735e commit a2c5f97

18 files changed

+125
-32
lines changed

pytype/abstract/_function_base.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,48 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
7070
name = self._get_cell_variable_name(a)
7171
assert name is not None, "Closure variable lookup failed."
7272
raise function.UndefinedParameterError(name)
73+
# The implementation of match_args is currently rather convoluted because we
74+
# have two different implementations:
75+
# * Old implementation: `_match_views` matches 'args' against 'self' one
76+
# view at a time, where a view is a mapping of every variable in args to a
77+
# particular binding. This handles generics but scales poorly with the
78+
# number of bindings per variable.
79+
# * New implementation: `_match_args_sequentially` matches 'args' one at a
80+
# time. This scales better but cannot yet handle generics.
81+
# Subclasses should implement the following:
82+
# * _match_view(node, args, view, alias_map): this will be called repeatedly
83+
# by _match_views.
84+
# * _match_args_sequentially(node, args, alias_map, match_all_views): A
85+
# sequential matching implementation.
86+
# TODO(b/228241343): Get rid of _match_views and simplify match_args once
87+
# _match_args_sequentially can handle generics.
88+
if self._is_generic_call(args):
89+
return self._match_views(node, args, alias_map, match_all_views)
90+
return self._match_args_sequentially(node, args, alias_map, match_all_views)
91+
92+
def _is_generic_call(self, args):
93+
for sig in function.get_signatures(self):
94+
for t in sig.annotations.values():
95+
stack = [t]
96+
seen = set()
97+
while stack:
98+
cur = stack.pop()
99+
if cur in seen:
100+
continue
101+
seen.add(cur)
102+
if cur.formal or cur.template:
103+
return True
104+
if _isinstance(cur, "Union"):
105+
stack.extend(cur.options)
106+
if self.is_attribute_of_class and args.posargs:
107+
for self_val in args.posargs[0].data:
108+
for cls in self_val.cls.mro:
109+
if cls.template:
110+
return True
111+
return False
112+
113+
def _match_views(self, node, args, alias_map, match_all_views):
114+
"""Matches all views of the given args against this function."""
73115
error = None
74116
matched = []
75117
arg_variables = args.get_variables()
@@ -107,6 +149,9 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
107149
def _match_view(self, node, args, view, alias_map):
108150
raise NotImplementedError(self.__class__.__name__)
109151

152+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
153+
raise NotImplementedError(self.__class__.__name__)
154+
110155
def __repr__(self):
111156
return self.full_name + "(...)"
112157

pytype/abstract/_instance_base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytype.abstract import function
1010

1111
log = logging.getLogger(__name__)
12+
_isinstance = abstract_utils._isinstance # pylint: disable=protected-access
1213

1314

1415
class SimpleValue(_base.BaseValue):
@@ -99,15 +100,18 @@ def merge_instance_type_parameter(self, node, name, value):
99100
else:
100101
self.instance_type_parameters[name] = value
101102

102-
def call(self, node, func, args, alias_map=None):
103-
binding = func if self == func.data else self.to_binding(node)
103+
def _call_helper(self, node, obj, binding, args):
104+
obj_binding = binding if obj == binding.data else obj.to_binding(node)
104105
node, var = self.ctx.attribute_handler.get_attribute(
105-
node, self, "__call__", binding)
106+
node, obj, "__call__", obj_binding)
106107
if var is not None and var.bindings:
107108
return function.call_function(self.ctx, node, var, args)
108109
else:
109110
raise function.NotCallable(self)
110111

112+
def call(self, node, func, args, alias_map=None):
113+
return self._call_helper(node, self, func, args)
114+
111115
def argcount(self, node):
112116
node, var = self.ctx.attribute_handler.get_attribute(
113117
node, self, "__call__", self.to_binding(node))

pytype/abstract/_interpreter_function.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,30 @@ def _match_view(self, node, args, view, alias_map=None):
234234
self.signature, args, self.ctx, bad_param=bad_arg)
235235
return subst
236236

237+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
238+
def match_succeeded(match_result):
239+
bad_matches, any_match = match_result
240+
if not bad_matches:
241+
return True
242+
if match_all_views or self.ctx.options.strict_parameter_checks:
243+
return False
244+
return any_match
245+
246+
for name, arg, formal in self.signature.iter_args(args):
247+
if formal is None:
248+
continue
249+
if name in (self.signature.varargs_name, self.signature.kwargs_name):
250+
# The annotation is Tuple or Dict, but the passed arg only has to be
251+
# Iterable or Mapping.
252+
formal = self.ctx.convert.widen_type(formal)
253+
match_result = self.ctx.matcher(node).bad_matches(arg, formal)
254+
if not match_succeeded(match_result):
255+
bad_arg = function.BadParam(
256+
name=name, expected=formal, error_details=match_result[0][0][1])
257+
raise function.WrongArgTypes(
258+
self.signature, args, self.ctx, bad_param=bad_arg)
259+
return [{}]
260+
237261
def get_first_opcode(self):
238262
return None
239263

pytype/abstract/_pytd_function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,9 @@ def _yield_matching_signatures(self, node, args, view, alias_map):
396396
if not matched:
397397
raise error # pylint: disable=raising-bad-type
398398

399+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
400+
return self._match_views(node, args, alias_map, match_all_views)
401+
399402
def set_function_defaults(self, unused_node, defaults_var):
400403
"""Attempts to set default arguments for a function's signatures.
401404

pytype/abstract/_typing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def __init__(self, name, ctx, base_cls):
7575
super().__init__(name, ctx)
7676
self.base_cls = base_cls
7777

78+
def __repr__(self):
79+
return "AnnotationContainer(%s)" % self.name
80+
7881
def _sub_annotation(
7982
self, annot: _base.BaseValue, subst: Mapping[str, _base.BaseValue]
8083
) -> _base.BaseValue:
@@ -277,7 +280,7 @@ def _build_value(self, node, inner, ellipses):
277280
root_node, container=abstract_utils.DUMMY_CONTAINER)
278281
else:
279282
actual = param_value.instantiate(root_node)
280-
bad = self.ctx.matcher(root_node).bad_matches(actual, formal_param)
283+
bad, _ = self.ctx.matcher(root_node).bad_matches(actual, formal_param)
281284
if bad:
282285
if not isinstance(param_value, TypeParameter):
283286
# If param_value is not a TypeVar, we substitute in TypeVar bounds
@@ -300,6 +303,9 @@ def _build_value(self, node, inner, ellipses):
300303
self.ctx.errorlog.invalid_annotation(self.ctx.vm.frames, e.annot, e.error)
301304
return self.ctx.convert.unsolvable
302305

306+
def call(self, node, func, args, alias_map=None):
307+
return self._call_helper(node, self.base_cls, func, args)
308+
303309

304310
class TypeParameter(_base.BaseValue):
305311
"""Parameter of a type."""

pytype/abstract/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_signatures(func):
3030
return [sig.drop_first_parameter() for sig in sigs] # drop "self"
3131
elif _isinstance(func, ("ClassMethod", "StaticMethod")):
3232
return get_signatures(func.method)
33-
elif _isinstance(func, "SimpleFunction"):
33+
elif _isinstance(func, "SignedFunction"):
3434
return [func.signature]
3535
elif _isinstance(func.cls, "CallableClass"):
3636
return [Signature.from_callable(func.cls)]

pytype/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def check_annotation_type_mismatch(
164164
typ, ("typing.ClassVar", "dataclasses.InitVar"))
165165
if contained_type:
166166
typ = contained_type
167-
bad = self.matcher(node).bad_matches(value, typ)
167+
bad, _ = self.matcher(node).bad_matches(value, typ)
168168
for view, error_details in bad:
169169
binding = view[value]
170170
self.errorlog.annotation_type_mismatch(

pytype/load_pytd.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,9 @@ def log_module_not_found(self, module_name):
286286
class _Resolver:
287287
"""Resolve symbols in a pytd tree."""
288288

289-
def __init__(self, builtins_ast, enable_nested_classes):
289+
def __init__(self, builtins_ast):
290290
self.builtins_ast = builtins_ast
291291
self.allow_singletons = False
292-
self._enable_nested_classes = enable_nested_classes
293292

294293
def _lookup(self, visitor, mod_ast, lookup_ast):
295294
if lookup_ast:
@@ -298,9 +297,7 @@ def _lookup(self, visitor, mod_ast, lookup_ast):
298297
return mod_ast
299298

300299
def resolve_local_types(self, mod_ast, *, lookup_ast=None):
301-
local_lookup = visitors.LookupLocalTypes(
302-
self.allow_singletons,
303-
enable_nested_classes=self._enable_nested_classes)
300+
local_lookup = visitors.LookupLocalTypes(self.allow_singletons)
304301
return self._lookup(local_lookup, mod_ast, lookup_ast)
305302

306303
def resolve_builtin_types(self, mod_ast, *, lookup_ast=None):
@@ -376,7 +373,7 @@ def __init__(self, options, modules=None):
376373
self._path_finder = _PathFinder(options)
377374
self._builtin_loader = builtin_stubs.BuiltinLoader(
378375
parser.PyiOptions.from_toplevel_options(options))
379-
self._resolver = _Resolver(self.builtins, options.enable_nested_classes)
376+
self._resolver = _Resolver(self.builtins)
380377
self._import_name_cache = {} # performance cache
381378
self._aliases = {}
382379
self._prefixes = set()

pytype/matcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,21 @@ def bad_matches(self, var, other_type):
175175
var: A cfg.Variable, containing instances.
176176
other_type: An instance of BaseValue.
177177
Returns:
178-
A list of all the views of var that didn't match.
178+
A pair of:
179+
* A list of all the views of var that didn't match.
180+
* Whether at least one view matched.
181+
TODO(b/63407497): We should be able to get rid of the second value once we
182+
start requiring that all views match.
179183
"""
180184
bad = []
181185
if (var.data == [self.ctx.convert.unsolvable] or
182186
other_type == self.ctx.convert.unsolvable):
183187
# An unsolvable matches everything. Since bad_matches doesn't need to
184188
# compute substitutions, we can return immediately.
185-
return bad
189+
return bad, True
186190
views = abstract_utils.get_views([var], self._node)
187191
skip_future = None
192+
any_match = False
188193
while True:
189194
try:
190195
view = views.send(skip_future)
@@ -200,7 +205,8 @@ def bad_matches(self, var, other_type):
200205
skip_future = False
201206
else:
202207
skip_future = True
203-
return bad
208+
any_match = True
209+
return bad, any_match
204210

205211
def match_from_mro(self, left, other_type, allow_compat_builtins=True):
206212
"""Checks a type's MRO for a match for a formal type.
@@ -989,7 +995,7 @@ def _match_dict_against_typed_dict(self, left, other_type):
989995
if k not in fields:
990996
continue
991997
typ = abstract_utils.get_atomic_value(fields[k])
992-
b = self.bad_matches(v, typ)
998+
b, _ = self.bad_matches(v, typ)
993999
if b:
9941000
bad.append((k, v, typ, b))
9951001
if missing or extra or bad:

pytype/matcher_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,14 +508,14 @@ def test_bad_matches(self):
508508
self.matcher.bad_matches(
509509
abstract.TypeParameter("T",
510510
self.ctx).to_variable(self.ctx.root_node),
511-
self.ctx.convert.unsolvable))
511+
self.ctx.convert.unsolvable)[0])
512512

513513
def test_bad_matches_no_match(self):
514514
self.assertTrue(
515515
self.matcher.bad_matches(
516516
abstract.TypeParameter("T",
517517
self.ctx).to_variable(self.ctx.root_node),
518-
self.ctx.convert.int_type))
518+
self.ctx.convert.int_type)[0])
519519

520520
def test_any(self):
521521
self.assertMatch(

0 commit comments

Comments
 (0)