Skip to content

Commit 0bf711f

Browse files
timmartincdce8p
andauthored
Infer returned value of .copy() method for collections (#1540)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
1 parent a39bbce commit 0bf711f

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ Release date: TBA
2020
* Rename ``ModuleSpec`` -> ``module_type`` constructor parameter to match attribute
2121
name and improve typing. Use ``type`` instead.
2222

23+
* Infer the return value of the ``.copy()`` method on ``dict``, ``list``, ``set``,
24+
and ``frozenset``.
25+
26+
Closes #1403
2327

2428
What's New in astroid 2.11.6?
2529
=============================

astroid/brain/brain_builtin_inference.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
"""Astroid hooks for various builtins."""
66

7+
import itertools
78
from functools import partial
8-
from typing import Optional
9+
from typing import Iterator, Optional
910

1011
from astroid import arguments, helpers, inference_tip, nodes, objects, util
1112
from astroid.builder import AstroidBuilder
@@ -892,6 +893,22 @@ def _build_dict_with_elements(elements):
892893
return _build_dict_with_elements([])
893894

894895

896+
def _infer_copy_method(
897+
node: nodes.Call, context: Optional[InferenceContext] = None
898+
) -> Iterator[nodes.NodeNG]:
899+
assert isinstance(node.func, nodes.Attribute)
900+
inferred_orig, inferred_copy = itertools.tee(node.func.expr.infer(context=context))
901+
if all(
902+
isinstance(
903+
inferred_node, (nodes.Dict, nodes.List, nodes.Set, objects.FrozenSet)
904+
)
905+
for inferred_node in inferred_orig
906+
):
907+
return inferred_copy
908+
909+
raise UseInferenceDefault()
910+
911+
895912
# Builtins inference
896913
register_builtin_transform(infer_bool, "bool")
897914
register_builtin_transform(infer_super, "super")
@@ -920,3 +937,10 @@ def _build_dict_with_elements(elements):
920937
inference_tip(_infer_object__new__decorator),
921938
_infer_object__new__decorator_check,
922939
)
940+
941+
AstroidManager().register_transform(
942+
nodes.Call,
943+
inference_tip(_infer_copy_method),
944+
lambda node: isinstance(node.func, nodes.Attribute)
945+
and node.func.attrname == "copy",
946+
)

tests/unittest_inference.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,6 +2051,37 @@ def test_dict_invalid_args(self) -> None:
20512051
self.assertIsInstance(inferred, Instance)
20522052
self.assertEqual(inferred.qname(), "builtins.dict")
20532053

2054+
def test_copy_method_inference(self) -> None:
2055+
code = """
2056+
a_dict = {"b": 1, "c": 2}
2057+
b_dict = a_dict.copy()
2058+
b_dict #@
2059+
2060+
a_list = [1, 2, 3]
2061+
b_list = a_list.copy()
2062+
b_list #@
2063+
2064+
a_set = set([1, 2, 3])
2065+
b_set = a_set.copy()
2066+
b_set #@
2067+
2068+
a_frozenset = frozenset([1, 2, 3])
2069+
b_frozenset = a_frozenset.copy()
2070+
b_frozenset #@
2071+
2072+
a_unknown = unknown()
2073+
b_unknown = a_unknown.copy()
2074+
b_unknown #@
2075+
"""
2076+
ast = extract_node(code, __name__)
2077+
self.assertInferDict(ast[0], {"b": 1, "c": 2})
2078+
self.assertInferList(ast[1], [1, 2, 3])
2079+
self.assertInferSet(ast[2], [1, 2, 3])
2080+
self.assertInferFrozenSet(ast[3], [1, 2, 3])
2081+
2082+
inferred_unknown = next(ast[4].infer())
2083+
assert inferred_unknown == util.Uninferable
2084+
20542085
def test_str_methods(self) -> None:
20552086
code = """
20562087
' '.decode() #@

0 commit comments

Comments
 (0)