Skip to content

Commit 0c8b761

Browse files
authored
stubgen: multiple fixes to the generated imports (#15624)
* Fix handling of nested imports. Instead of assuming that a name is imported from a top level package, look in the imports for this name starting from the parent submodule up until the import is found * Fix "from imports" getting reexported unnecessarily * Fix import sorting when having import aliases Fixes #13661 Fixes #7006
1 parent 9edda9a commit 0c8b761

File tree

2 files changed

+74
-10
lines changed

2 files changed

+74
-10
lines changed

mypy/stubgen.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,9 @@ def add_import(self, module: str, alias: str | None = None) -> None:
496496
name = name.rpartition(".")[0]
497497

498498
def require_name(self, name: str) -> None:
499-
self.required_names.add(name.split(".")[0])
499+
while name not in self.direct_imports and "." in name:
500+
name = name.rsplit(".", 1)[0]
501+
self.required_names.add(name)
500502

501503
def reexport(self, name: str) -> None:
502504
"""Mark a given non qualified name as needed in __all__.
@@ -516,7 +518,10 @@ def import_lines(self) -> list[str]:
516518
# be imported from it. the names can also be alias in the form 'original as alias'
517519
module_map: Mapping[str, list[str]] = defaultdict(list)
518520

519-
for name in sorted(self.required_names):
521+
for name in sorted(
522+
self.required_names,
523+
key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
524+
):
520525
# If we haven't seen this name in an import statement, ignore it
521526
if name not in self.module_for:
522527
continue
@@ -540,7 +545,7 @@ def import_lines(self) -> list[str]:
540545
assert "." not in name # Because reexports only has nonqualified names
541546
result.append(f"import {name} as {name}\n")
542547
else:
543-
result.append(f"import {self.direct_imports[name]}\n")
548+
result.append(f"import {name}\n")
544549

545550
# Now generate all the from ... import ... lines collected in module_map
546551
for module, names in sorted(module_map.items()):
@@ -595,7 +600,7 @@ def visit_name_expr(self, e: NameExpr) -> None:
595600
self.refs.add(e.name)
596601

597602
def visit_instance(self, t: Instance) -> None:
598-
self.add_ref(t.type.fullname)
603+
self.add_ref(t.type.name)
599604
super().visit_instance(t)
600605

601606
def visit_unbound_type(self, t: UnboundType) -> None:
@@ -614,7 +619,10 @@ def visit_callable_type(self, t: CallableType) -> None:
614619
t.ret_type.accept(self)
615620

616621
def add_ref(self, fullname: str) -> None:
617-
self.refs.add(fullname.split(".")[-1])
622+
self.refs.add(fullname)
623+
while "." in fullname:
624+
fullname = fullname.rsplit(".", 1)[0]
625+
self.refs.add(fullname)
618626

619627

620628
class StubGenerator(mypy.traverser.TraverserVisitor):
@@ -1295,6 +1303,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
12951303
if (
12961304
as_name is None
12971305
and name not in self.referenced_names
1306+
and not any(n.startswith(name + ".") for n in self.referenced_names)
12981307
and (not self._all_ or name in IGNORED_DUNDERS)
12991308
and not is_private
13001309
and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES
@@ -1303,14 +1312,15 @@ def visit_import_from(self, o: ImportFrom) -> None:
13031312
# exported, unless there is an explicit __all__. Note that we need to special
13041313
# case 'abc' since some references are deleted during semantic analysis.
13051314
exported = True
1306-
top_level = full_module.split(".")[0]
1315+
top_level = full_module.split(".", 1)[0]
1316+
self_top_level = self.module.split(".", 1)[0]
13071317
if (
13081318
as_name is None
13091319
and not self.export_less
13101320
and (not self._all_ or name in IGNORED_DUNDERS)
13111321
and self.module
13121322
and not is_private
1313-
and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0])
1323+
and top_level in (self_top_level, "_" + self_top_level)
13141324
):
13151325
# Export imports from the same package, since we can't reliably tell whether they
13161326
# are part of the public API.

test-data/unit/stubgen.test

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,9 +2772,9 @@ y: b.Y
27722772
z: p.a.X
27732773

27742774
[out]
2775+
import p.a
27752776
import p.a as a
27762777
import p.b as b
2777-
import p.a
27782778

27792779
x: a.X
27802780
y: b.Y
@@ -2787,7 +2787,7 @@ from p import a
27872787
x: a.X
27882788

27892789
[out]
2790-
from p import a as a
2790+
from p import a
27912791

27922792
x: a.X
27932793

@@ -2809,7 +2809,7 @@ from p import a
28092809
x: a.X
28102810

28112811
[out]
2812-
from p import a as a
2812+
from p import a
28132813

28142814
x: a.X
28152815

@@ -2859,6 +2859,60 @@ import p.a
28592859
x: a.X
28602860
y: p.a.Y
28612861

2862+
[case testNestedImports]
2863+
import p
2864+
import p.m1
2865+
import p.m2
2866+
2867+
x: p.X
2868+
y: p.m1.Y
2869+
z: p.m2.Z
2870+
2871+
[out]
2872+
import p
2873+
import p.m1
2874+
import p.m2
2875+
2876+
x: p.X
2877+
y: p.m1.Y
2878+
z: p.m2.Z
2879+
2880+
[case testNestedImportsAliased]
2881+
import p as t
2882+
import p.m1 as pm1
2883+
import p.m2 as pm2
2884+
2885+
x: t.X
2886+
y: pm1.Y
2887+
z: pm2.Z
2888+
2889+
[out]
2890+
import p as t
2891+
import p.m1 as pm1
2892+
import p.m2 as pm2
2893+
2894+
x: t.X
2895+
y: pm1.Y
2896+
z: pm2.Z
2897+
2898+
[case testNestedFromImports]
2899+
from p import m1
2900+
from p.m1 import sm1
2901+
from p.m2 import sm2
2902+
2903+
x: m1.X
2904+
y: sm1.Y
2905+
z: sm2.Z
2906+
2907+
[out]
2908+
from p import m1
2909+
from p.m1 import sm1
2910+
from p.m2 import sm2
2911+
2912+
x: m1.X
2913+
y: sm1.Y
2914+
z: sm2.Z
2915+
28622916
[case testOverload_fromTypingImport]
28632917
from typing import Tuple, Union, overload
28642918

0 commit comments

Comments
 (0)