@@ -496,7 +496,9 @@ def add_import(self, module: str, alias: str | None = None) -> None:
496
496
name = name .rpartition ("." )[0 ]
497
497
498
498
def require_name (self , name : str ) -> None :
499
- self .required_names .add (name .split ("." )[0 ])
499
+ while name not in self .direct_imports and "." in name :
500
+ name = name .rsplit ("." , 1 )[0 ]
501
+ self .required_names .add (name )
500
502
501
503
def reexport (self , name : str ) -> None :
502
504
"""Mark a given non qualified name as needed in __all__.
@@ -516,7 +518,10 @@ def import_lines(self) -> list[str]:
516
518
# be imported from it. the names can also be alias in the form 'original as alias'
517
519
module_map : Mapping [str , list [str ]] = defaultdict (list )
518
520
519
- for name in sorted (self .required_names ):
521
+ for name in sorted (
522
+ self .required_names ,
523
+ key = lambda n : (self .reverse_alias [n ], n ) if n in self .reverse_alias else (n , "" ),
524
+ ):
520
525
# If we haven't seen this name in an import statement, ignore it
521
526
if name not in self .module_for :
522
527
continue
@@ -540,7 +545,7 @@ def import_lines(self) -> list[str]:
540
545
assert "." not in name # Because reexports only has nonqualified names
541
546
result .append (f"import { name } as { name } \n " )
542
547
else :
543
- result .append (f"import { self . direct_imports [ name ] } \n " )
548
+ result .append (f"import { name } \n " )
544
549
545
550
# Now generate all the from ... import ... lines collected in module_map
546
551
for module , names in sorted (module_map .items ()):
@@ -595,7 +600,7 @@ def visit_name_expr(self, e: NameExpr) -> None:
595
600
self .refs .add (e .name )
596
601
597
602
def visit_instance (self , t : Instance ) -> None :
598
- self .add_ref (t .type .fullname )
603
+ self .add_ref (t .type .name )
599
604
super ().visit_instance (t )
600
605
601
606
def visit_unbound_type (self , t : UnboundType ) -> None :
@@ -614,7 +619,10 @@ def visit_callable_type(self, t: CallableType) -> None:
614
619
t .ret_type .accept (self )
615
620
616
621
def add_ref (self , fullname : str ) -> None :
617
- self .refs .add (fullname .split ("." )[- 1 ])
622
+ self .refs .add (fullname )
623
+ while "." in fullname :
624
+ fullname = fullname .rsplit ("." , 1 )[0 ]
625
+ self .refs .add (fullname )
618
626
619
627
620
628
class StubGenerator (mypy .traverser .TraverserVisitor ):
@@ -1295,6 +1303,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
1295
1303
if (
1296
1304
as_name is None
1297
1305
and name not in self .referenced_names
1306
+ and not any (n .startswith (name + "." ) for n in self .referenced_names )
1298
1307
and (not self ._all_ or name in IGNORED_DUNDERS )
1299
1308
and not is_private
1300
1309
and module not in ("abc" , "asyncio" ) + TYPING_MODULE_NAMES
@@ -1303,14 +1312,15 @@ def visit_import_from(self, o: ImportFrom) -> None:
1303
1312
# exported, unless there is an explicit __all__. Note that we need to special
1304
1313
# case 'abc' since some references are deleted during semantic analysis.
1305
1314
exported = True
1306
- top_level = full_module .split ("." )[0 ]
1315
+ top_level = full_module .split ("." , 1 )[0 ]
1316
+ self_top_level = self .module .split ("." , 1 )[0 ]
1307
1317
if (
1308
1318
as_name is None
1309
1319
and not self .export_less
1310
1320
and (not self ._all_ or name in IGNORED_DUNDERS )
1311
1321
and self .module
1312
1322
and not is_private
1313
- and top_level in (self . module . split ( "." )[ 0 ] , "_" + self . module . split ( "." )[ 0 ] )
1323
+ and top_level in (self_top_level , "_" + self_top_level )
1314
1324
):
1315
1325
# Export imports from the same package, since we can't reliably tell whether they
1316
1326
# are part of the public API.
0 commit comments