Skip to content

Commit c015036

Browse files
kinto0facebook-github-bot
authored andcommitted
represent imports better
Summary: by using `Ast::imports` to find imports, we considered the entire import stmt including keywords in what we consider a "module". But in reality, we should only consider the names*. This diff makes a `ImportIdentifier` enum, walking the statements and grabbing the necessary items. For LSP features like completion, I believe the x in `from x import y` is very similar to the the x in `import x`. For that reason I make both of these `x`s `ImportIdentifier::Module`. But the names are different and get their own enum variant (to be used in next diff). Reviewed By: SamChou19815 Differential Revision: D75473401 fbshipit-source-id: ddcf1b5da352e7bb9ab59e046784f09587f1cd02
1 parent 3891e38 commit c015036

File tree

3 files changed

+91
-14
lines changed

3 files changed

+91
-14
lines changed

pyrefly/lib/state/lsp.rs

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ pub enum DefinitionMetadata {
6060
VariableOrAttribute(Name),
6161
}
6262

63+
enum ImportIdentifier {
64+
// The name of a module. ex: `x` in `import x` or `from x import name`
65+
Module(ModuleName),
66+
// A name from a module's exports. ex: `name` in `from x import name`
67+
// Note: these are also definitions
68+
Name(ModuleName),
69+
}
70+
71+
impl ImportIdentifier {
72+
fn module_name(&self) -> ModuleName {
73+
match self {
74+
ImportIdentifier::Module(module_name) => *module_name,
75+
ImportIdentifier::Name(module_name) => *module_name,
76+
}
77+
}
78+
}
79+
6380
impl<'a> Transaction<'a> {
6481
fn get_type(&self, handle: &Handle, key: &Key) -> Option<Type> {
6582
let idx = self.get_bindings(handle)?.key_to_idx(key);
@@ -88,15 +105,43 @@ impl<'a> Transaction<'a> {
88105
res
89106
}
90107

91-
fn import_at(&self, handle: &Handle, position: TextSize) -> Option<ModuleName> {
92-
let module = self.get_ast(handle)?;
93-
for (module, text_range) in Ast::imports(&module, handle.module(), handle.path().is_init())
94-
{
95-
if text_range.contains_inclusive(position) {
96-
return Some(module);
108+
fn import_at(&self, handle: &Handle, position: TextSize) -> Option<ImportIdentifier> {
109+
fn visit_stmt(x: &Stmt, find: TextSize, res: &mut Option<ImportIdentifier>) {
110+
match x {
111+
Stmt::Import(stmt_import) => {
112+
let mut parts = Vec::new();
113+
for name in stmt_import.names.iter() {
114+
parts.push(name.name.clone());
115+
if name.range.contains_inclusive(find) {
116+
*res = Some(ImportIdentifier::Module(ModuleName::from_parts(
117+
parts.clone(),
118+
)));
119+
}
120+
}
121+
}
122+
Stmt::ImportFrom(stmt_import_from) => {
123+
if let Some(id) = &stmt_import_from.module {
124+
if id.range.contains_inclusive(find) {
125+
*res = Some(ImportIdentifier::Module(ModuleName::from_name(&id.id)));
126+
} else {
127+
for name in stmt_import_from.names.iter() {
128+
if name.range.contains_inclusive(find) {
129+
*res =
130+
Some(ImportIdentifier::Name(ModuleName::from_name(&id.id)));
131+
}
132+
}
133+
}
134+
}
135+
}
136+
_ => x.recurse(&mut |x| visit_stmt(x, find, res)),
97137
}
98138
}
99-
None
139+
140+
let mut res = None;
141+
self.get_ast(handle)?
142+
.body
143+
.visit(&mut |x| visit_stmt(x, position, &mut res));
144+
res
100145
}
101146

102147
fn definition_at(&self, handle: &Handle, position: TextSize) -> Option<Key> {
@@ -133,9 +178,10 @@ impl<'a> Transaction<'a> {
133178
}
134179
}
135180
if let Some(m) = self.import_at(handle, position) {
181+
let module_name = m.module_name();
136182
return Some(Type::Module(Module::new(
137-
m.components().first().unwrap().clone(),
138-
OrderedSet::from_iter([(m)]),
183+
module_name.components().first().unwrap().clone(),
184+
OrderedSet::from_iter([module_name]),
139185
)));
140186
}
141187
let attribute = self.attribute_at(handle, position)?;
@@ -273,7 +319,8 @@ impl<'a> Transaction<'a> {
273319
));
274320
}
275321
if let Some(m) = self.import_at(handle, position) {
276-
let handle = self.import_handle(handle, m, None).ok()?;
322+
let module_name = m.module_name();
323+
let handle = self.import_handle(handle, module_name, None).ok()?;
277324
return Some((
278325
DefinitionMetadata::Module,
279326
TextRangeWithModuleInfo::new(self.get_module_info(&handle)?, TextRange::default()),

pyrefly/lib/test/lsp/definition.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def f(x: list[int], y: str, z: Literal[42]):
107107
let code_test: &str = r#"
108108
from typing import Literal
109109
from .import_provider import f
110-
# ^ ^
110+
# ^ ^ ^
111111
112112
foo: Literal[1] = 1
113113
# ^
@@ -125,6 +125,10 @@ bar = f([1], "", 42)
125125
assert_eq!(
126126
r#"
127127
# main.py
128+
3 | from .import_provider import f
129+
^
130+
Definition Result: None
131+
128132
3 | from .import_provider import f
129133
^
130134
Definition Result:

pyrefly/lib/test/lsp/hover_type.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String
2424
fn basic_test() {
2525
let code = r#"
2626
from typing import Literal
27-
# ^
27+
# ^ ^ ^
2828
def f(x: list[int], y: str, z: Literal[42]):
2929
# ^ ^ ^
3030
return x
@@ -40,6 +40,14 @@ yyy = f([1, 2, 3], "test", 42)
4040
^
4141
Hover Result: `Module[typing]`
4242
43+
2 | from typing import Literal
44+
^
45+
Hover Result: None
46+
47+
2 | from typing import Literal
48+
^
49+
Hover Result: `type[Literal]`
50+
4351
4 | def f(x: list[int], y: str, z: Literal[42]):
4452
^
4553
Hover Result: `(x: list[int], y: str, z: Literal[42]) -> list[int]`
@@ -83,6 +91,25 @@ Hover Result: `Module[typing]`
8391
);
8492
}
8593

94+
#[test]
95+
fn import_alias_test() {
96+
let code = r#"
97+
import typing as t
98+
# ^
99+
"#;
100+
let report = get_batched_lsp_operations_report(&[("main", code)], get_test_report);
101+
assert_eq!(
102+
r#"
103+
# main.py
104+
2 | import typing as t
105+
^
106+
Hover Result: `Module[typing]`
107+
"#
108+
.trim(),
109+
report.trim(),
110+
);
111+
}
112+
86113
#[test]
87114
fn duplicate_import_test() {
88115
let code = r#"
@@ -91,13 +118,12 @@ import typing
91118
# ^
92119
"#;
93120
let report = get_batched_lsp_operations_report(&[("main", code)], get_test_report);
94-
// TODO(kylei): The result should be `Module[typing]`
95121
assert_eq!(
96122
r#"
97123
# main.py
98124
3 | import typing
99125
^
100-
Hover Result: None
126+
Hover Result: `Module[typing]`
101127
"#
102128
.trim(),
103129
report.trim(),

0 commit comments

Comments
 (0)