Skip to content

Commit 855735e

Browse files
committed
Add code structure for gradual migration to argument-by-argument call matching.
This adds some code that will let our current function call matching algorithm and a new, faster one exist side-by-side until the new one can handle all the cases that the current one does. It also enables the new matching for trivial cases (when we're matching against an InterpreterFunction with no TypeVars in the function signature, the containing class (for a method), or the function arguments). This is already enough to show a pretty impressive speedup in our function call benchmark but not in the real-world use case I tested. If anything, it's a tad slower because of the _is_generic_call check, but that'll go away eventually. PiperOrigin-RevId: 446044514
1 parent 2e78dd1 commit 855735e

File tree

11 files changed

+90
-12
lines changed

11 files changed

+90
-12
lines changed

pytype/abstract/_function_base.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,48 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
7070
name = self._get_cell_variable_name(a)
7171
assert name is not None, "Closure variable lookup failed."
7272
raise function.UndefinedParameterError(name)
73+
# The implementation of match_args is currently rather convoluted because we
74+
# have two different implementations:
75+
# * Old implementation: `_match_views` matches 'args' against 'self' one
76+
# view at a time, where a view is a mapping of every variable in args to a
77+
# particular binding. This handles generics but scales poorly with the
78+
# number of bindings per variable.
79+
# * New implementation: `_match_args_sequentially` matches 'args' one at a
80+
# time. This scales better but cannot yet handle generics.
81+
# Subclasses should implement the following:
82+
# * _match_view(node, args, view, alias_map): this will be called repeatedly
83+
# by _match_views.
84+
# * _match_args_sequentially(node, args, alias_map, match_all_views): A
85+
# sequential matching implementation.
86+
# TODO(b/228241343): Get rid of _match_views and simplify match_args once
87+
# _match_args_sequentially can handle generics.
88+
if self._is_generic_call(args):
89+
return self._match_views(node, args, alias_map, match_all_views)
90+
return self._match_args_sequentially(node, args, alias_map, match_all_views)
91+
92+
def _is_generic_call(self, args):
93+
for sig in function.get_signatures(self):
94+
for t in sig.annotations.values():
95+
stack = [t]
96+
seen = set()
97+
while stack:
98+
cur = stack.pop()
99+
if cur in seen:
100+
continue
101+
seen.add(cur)
102+
if cur.formal or cur.template:
103+
return True
104+
if _isinstance(cur, "Union"):
105+
stack.extend(cur.options)
106+
if self.is_attribute_of_class and args.posargs:
107+
for self_val in args.posargs[0].data:
108+
for cls in self_val.cls.mro:
109+
if cls.template:
110+
return True
111+
return False
112+
113+
def _match_views(self, node, args, alias_map, match_all_views):
114+
"""Matches all views of the given args against this function."""
73115
error = None
74116
matched = []
75117
arg_variables = args.get_variables()
@@ -107,6 +149,9 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
107149
def _match_view(self, node, args, view, alias_map):
108150
raise NotImplementedError(self.__class__.__name__)
109151

152+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
153+
raise NotImplementedError(self.__class__.__name__)
154+
110155
def __repr__(self):
111156
return self.full_name + "(...)"
112157

pytype/abstract/_interpreter_function.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,30 @@ def _match_view(self, node, args, view, alias_map=None):
234234
self.signature, args, self.ctx, bad_param=bad_arg)
235235
return subst
236236

237+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
238+
def match_succeeded(match_result):
239+
bad_matches, any_match = match_result
240+
if not bad_matches:
241+
return True
242+
if match_all_views or self.ctx.options.strict_parameter_checks:
243+
return False
244+
return any_match
245+
246+
for name, arg, formal in self.signature.iter_args(args):
247+
if formal is None:
248+
continue
249+
if name in (self.signature.varargs_name, self.signature.kwargs_name):
250+
# The annotation is Tuple or Dict, but the passed arg only has to be
251+
# Iterable or Mapping.
252+
formal = self.ctx.convert.widen_type(formal)
253+
match_result = self.ctx.matcher(node).bad_matches(arg, formal)
254+
if not match_succeeded(match_result):
255+
bad_arg = function.BadParam(
256+
name=name, expected=formal, error_details=match_result[0][0][1])
257+
raise function.WrongArgTypes(
258+
self.signature, args, self.ctx, bad_param=bad_arg)
259+
return [{}]
260+
237261
def get_first_opcode(self):
238262
return None
239263

pytype/abstract/_pytd_function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,9 @@ def _yield_matching_signatures(self, node, args, view, alias_map):
396396
if not matched:
397397
raise error # pylint: disable=raising-bad-type
398398

399+
def _match_args_sequentially(self, node, args, alias_map, match_all_views):
400+
return self._match_views(node, args, alias_map, match_all_views)
401+
399402
def set_function_defaults(self, unused_node, defaults_var):
400403
"""Attempts to set default arguments for a function's signatures.
401404

pytype/abstract/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _build_value(self, node, inner, ellipses):
280280
root_node, container=abstract_utils.DUMMY_CONTAINER)
281281
else:
282282
actual = param_value.instantiate(root_node)
283-
bad = self.ctx.matcher(root_node).bad_matches(actual, formal_param)
283+
bad, _ = self.ctx.matcher(root_node).bad_matches(actual, formal_param)
284284
if bad:
285285
if not isinstance(param_value, TypeParameter):
286286
# If param_value is not a TypeVar, we substitute in TypeVar bounds

pytype/abstract/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_signatures(func):
3030
return [sig.drop_first_parameter() for sig in sigs] # drop "self"
3131
elif _isinstance(func, ("ClassMethod", "StaticMethod")):
3232
return get_signatures(func.method)
33-
elif _isinstance(func, "SimpleFunction"):
33+
elif _isinstance(func, "SignedFunction"):
3434
return [func.signature]
3535
elif _isinstance(func.cls, "CallableClass"):
3636
return [Signature.from_callable(func.cls)]

pytype/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def check_annotation_type_mismatch(
164164
typ, ("typing.ClassVar", "dataclasses.InitVar"))
165165
if contained_type:
166166
typ = contained_type
167-
bad = self.matcher(node).bad_matches(value, typ)
167+
bad, _ = self.matcher(node).bad_matches(value, typ)
168168
for view, error_details in bad:
169169
binding = view[value]
170170
self.errorlog.annotation_type_mismatch(

pytype/matcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,21 @@ def bad_matches(self, var, other_type):
175175
var: A cfg.Variable, containing instances.
176176
other_type: An instance of BaseValue.
177177
Returns:
178-
A list of all the views of var that didn't match.
178+
A pair of:
179+
* A list of all the views of var that didn't match.
180+
* Whether at least one view matched.
181+
TODO(b/63407497): We should be able to get rid of the second value once we
182+
start requiring that all views match.
179183
"""
180184
bad = []
181185
if (var.data == [self.ctx.convert.unsolvable] or
182186
other_type == self.ctx.convert.unsolvable):
183187
# An unsolvable matches everything. Since bad_matches doesn't need to
184188
# compute substitutions, we can return immediately.
185-
return bad
189+
return bad, True
186190
views = abstract_utils.get_views([var], self._node)
187191
skip_future = None
192+
any_match = False
188193
while True:
189194
try:
190195
view = views.send(skip_future)
@@ -200,7 +205,8 @@ def bad_matches(self, var, other_type):
200205
skip_future = False
201206
else:
202207
skip_future = True
203-
return bad
208+
any_match = True
209+
return bad, any_match
204210

205211
def match_from_mro(self, left, other_type, allow_compat_builtins=True):
206212
"""Checks a type's MRO for a match for a formal type.
@@ -989,7 +995,7 @@ def _match_dict_against_typed_dict(self, left, other_type):
989995
if k not in fields:
990996
continue
991997
typ = abstract_utils.get_atomic_value(fields[k])
992-
b = self.bad_matches(v, typ)
998+
b, _ = self.bad_matches(v, typ)
993999
if b:
9941000
bad.append((k, v, typ, b))
9951001
if missing or extra or bad:

pytype/matcher_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,14 +508,14 @@ def test_bad_matches(self):
508508
self.matcher.bad_matches(
509509
abstract.TypeParameter("T",
510510
self.ctx).to_variable(self.ctx.root_node),
511-
self.ctx.convert.unsolvable))
511+
self.ctx.convert.unsolvable)[0])
512512

513513
def test_bad_matches_no_match(self):
514514
self.assertTrue(
515515
self.matcher.bad_matches(
516516
abstract.TypeParameter("T",
517517
self.ctx).to_variable(self.ctx.root_node),
518-
self.ctx.convert.int_type))
518+
self.ctx.convert.int_type)[0])
519519

520520
def test_any(self):
521521
self.assertMatch(

pytype/overlays/typed_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _check_str_key_value(self, node, name, value_var):
228228
if not self._check_str_key(name):
229229
return
230230
typ = abstract_utils.get_atomic_value(self.fields[name])
231-
bad = self.ctx.matcher(node).bad_matches(value_var, typ)
231+
bad, _ = self.ctx.matcher(node).bad_matches(value_var, typ)
232232
for view, error_details in bad:
233233
binding = view[value_var]
234234
self.ctx.errorlog.annotation_type_mismatch(

pytype/overriding_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def _check_signature_compatible(method_signature, base_signature,
417417
def is_subtype(this_type, that_type):
418418
"""Return True iff this_type is a subclass of that_type."""
419419
this_type_instance = this_type.instantiate(ctx.root_node, None)
420-
return not matcher.bad_matches(this_type_instance, that_type)
420+
return not matcher.bad_matches(this_type_instance, that_type)[0]
421421

422422
check_result = (
423423
_check_positional_parameters(method_signature, base_signature, is_subtype)

0 commit comments

Comments
 (0)