Skip to content

Commit d31cfab

Browse files
committed
Add overload_idx to name path matching, reifying NamePathMatcher for efficency
Implements #515
1 parent 55b4a66 commit d31cfab

File tree

2 files changed

+76
-34
lines changed

2 files changed

+76
-34
lines changed

src/serena/symbol.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .agent import SerenaAgent
2020

2121
log = logging.getLogger(__name__)
22+
NAME_PATH_SEP = "/"
2223

2324

2425
@dataclass
@@ -113,45 +114,65 @@ def is_neighbouring_definition_separated_by_empty_line(self) -> bool:
113114
"""
114115

115116

116-
class LanguageServerSymbol(Symbol, ToStringMixin):
117-
_NAME_PATH_SEP = "/"
118-
119-
@staticmethod
120-
def match_name_path(
121-
name_path: str,
122-
symbol_name_path_parts: list[str],
123-
substring_matching: bool,
124-
) -> bool:
117+
class NamePathMatcher(ToStringMixin):
118+
def __init__(self, name_path_expr: str, substring_matching: bool) -> None:
125119
"""
126-
Checks if a given `name_path` matches a symbol's qualified name parts.
127-
See docstring of `Symbol.find` for more details.
120+
:param name_path_expr: the name path expression to match against
121+
:param substring_matching: whether to use substring matching for the last segment
128122
"""
129-
assert name_path, "name_path must not be empty"
130-
assert symbol_name_path_parts, "symbol_name_path_parts must not be empty"
131-
name_path_sep = LanguageServerSymbol._NAME_PATH_SEP
123+
assert name_path_expr, "name_path must not be empty"
124+
self._expr = name_path_expr
125+
self._substring_matching = substring_matching
126+
self._is_absolute_pattern = name_path_expr.startswith(NAME_PATH_SEP)
127+
self._pattern_parts = name_path_expr.lstrip(NAME_PATH_SEP).rstrip(NAME_PATH_SEP).split(NAME_PATH_SEP)
128+
129+
# extract overload index "[idx]" if present at end of last part
130+
self._overload_idx: int | None = None
131+
last_part = self._pattern_parts[-1]
132+
if last_part.endswith("]") and "[" in last_part:
133+
bracket_idx = last_part.rfind("[")
134+
index_part = last_part[bracket_idx + 1 : -1]
135+
if index_part.isdigit():
136+
self._pattern_parts[-1] = last_part[:bracket_idx]
137+
self._overload_idx = int(index_part)
132138

133-
is_absolute_pattern = name_path.startswith(name_path_sep)
134-
pattern_parts = name_path.lstrip(name_path_sep).rstrip(name_path_sep).split(name_path_sep)
139+
def _tostring_includes(self) -> list[str]:
140+
return ["_expr"]
141+
142+
def matches_ls_symbol(self, symbol: "LanguageServerSymbol") -> bool:
143+
return self.matches_components(symbol.get_name_path_parts(), symbol.overload_idx)
135144

145+
def matches_components(self, symbol_name_path_parts: list[str], overload_idx: int | None) -> bool:
136146
# filtering based on ancestors
137-
if len(pattern_parts) > len(symbol_name_path_parts):
147+
if len(self._pattern_parts) > len(symbol_name_path_parts):
138148
# can't possibly match if pattern has more parts than symbol
139149
return False
140-
if is_absolute_pattern and len(pattern_parts) != len(symbol_name_path_parts):
150+
if self._is_absolute_pattern and len(self._pattern_parts) != len(symbol_name_path_parts):
141151
# for absolute patterns, the number of parts must match exactly
142152
return False
143-
if symbol_name_path_parts[-len(pattern_parts) : -1] != pattern_parts[:-1]:
153+
if symbol_name_path_parts[-len(self._pattern_parts) : -1] != self._pattern_parts[:-1]:
144154
# ancestors must match
145155
return False
146156

147157
# matching the last part of the symbol name
148-
name_to_match = pattern_parts[-1]
158+
name_to_match = self._pattern_parts[-1]
149159
symbol_name = symbol_name_path_parts[-1]
150-
if substring_matching:
151-
return name_to_match in symbol_name
160+
if self._substring_matching:
161+
if name_to_match not in symbol_name:
162+
return False
152163
else:
153-
return name_to_match == symbol_name
164+
if name_to_match != symbol_name:
165+
return False
154166

167+
# check for matching overload index
168+
if self._overload_idx is not None:
169+
if overload_idx != self._overload_idx:
170+
return False
171+
172+
return True
173+
174+
175+
class LanguageServerSymbol(Symbol, ToStringMixin):
155176
def __init__(self, symbol_root_from_ls: UnifiedSymbolInformation) -> None:
156177
self.symbol_root = symbol_root_from_ls
157178

@@ -173,6 +194,10 @@ def kind(self) -> str:
173194
def symbol_kind(self) -> SymbolKind:
174195
return self.symbol_root["kind"]
175196

197+
@property
198+
def overload_idx(self) -> int | None:
199+
return self.symbol_root.get("overload_idx")
200+
176201
def is_neighbouring_definition_separated_by_empty_line(self) -> bool:
177202
return self.symbol_kind in (SymbolKind.Function, SymbolKind.Method, SymbolKind.Class, SymbolKind.Interface, SymbolKind.Struct)
178203

@@ -256,9 +281,13 @@ def body(self) -> str | None:
256281

257282
def get_name_path(self) -> str:
258283
"""
259-
Get the name path of the symbol (e.g. "class/method/inner_function").
284+
Get the name path of the symbol, e.g. "class/method/inner_function" or
285+
"class/method[1]" (overloaded method with identifying index).
260286
"""
261-
return self._NAME_PATH_SEP.join(self.get_name_path_parts())
287+
name_path = NAME_PATH_SEP.join(self.get_name_path_parts())
288+
if "overload_idx" in self.symbol_root:
289+
name_path += f"[{self.symbol_root['overload_idx']}]"
290+
return name_path
262291

263292
def get_name_path_parts(self) -> list[str]:
264293
"""
@@ -321,25 +350,22 @@ def find(
321350
For example, passing `/class` will match only against top-level symbols like `class` but not against `nested_class/class`.
322351
Passing `/class/method` will match against `class/method` but not `nested_class/class/method` or `method`.
323352
324-
:param name_path: the name path to match against
353+
:param name_path: the name path expression to match against
325354
:param substring_matching: whether to use substring matching (as opposed to exact matching)
326355
of the last segment of `name_path` against the symbol name.
327356
:param include_kinds: an optional sequence of ints representing the LSP symbol kind.
328357
If provided, only symbols of the given kinds will be included in the result.
329358
:param exclude_kinds: If provided, symbols of the given kinds will be excluded from the result.
330359
"""
331360
result = []
361+
name_path_matcher = NamePathMatcher(name_path, substring_matching)
332362

333363
def should_include(s: "LanguageServerSymbol") -> bool:
334364
if include_kinds is not None and s.symbol_kind not in include_kinds:
335365
return False
336366
if exclude_kinds is not None and s.symbol_kind in exclude_kinds:
337367
return False
338-
return LanguageServerSymbol.match_name_path(
339-
name_path=name_path,
340-
symbol_name_path_parts=s.get_name_path_parts(),
341-
substring_matching=substring_matching,
342-
)
368+
return name_path_matcher.matches_ls_symbol(s)
343369

344370
def traverse(s: "LanguageServerSymbol") -> None:
345371
if should_include(s):

test/serena/test_symbol.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from src.serena.symbol import LanguageServerSymbol
3+
from serena.symbol import NamePathMatcher
44

55

66
class TestSymbolNameMatching:
@@ -77,7 +77,7 @@ def _create_assertion_error_message(
7777
)
7878
def test_match_simple_name(self, name_path_pattern, symbol_name_path_parts, is_substring_match, expected):
7979
"""Tests matching for simple names (no '/' in pattern)."""
80-
result = LanguageServerSymbol.match_name_path(name_path_pattern, symbol_name_path_parts, is_substring_match)
80+
result = NamePathMatcher(name_path_pattern, is_substring_match).matches_components(symbol_name_path_parts, None)
8181
error_msg = self._create_assertion_error_message(name_path_pattern, symbol_name_path_parts, is_substring_match, expected, result)
8282
assert result == expected, error_msg
8383

@@ -157,6 +157,22 @@ def test_match_simple_name(self, name_path_pattern, symbol_name_path_parts, is_s
157157
)
158158
def test_match_name_path_pattern_path_len_2(self, name_path_pattern, symbol_name_path_parts, is_substring_match, expected):
159159
"""Tests matching for qualified names (e.g. 'module/class/func')."""
160-
result = LanguageServerSymbol.match_name_path(name_path_pattern, symbol_name_path_parts, is_substring_match)
160+
result = NamePathMatcher(name_path_pattern, is_substring_match).matches_components(symbol_name_path_parts, None)
161161
error_msg = self._create_assertion_error_message(name_path_pattern, symbol_name_path_parts, is_substring_match, expected, result)
162162
assert result == expected, error_msg
163+
164+
@pytest.mark.parametrize(
165+
"name_path_pattern, symbol_name_path_parts, symbol_overload_idx, expected",
166+
[
167+
pytest.param("bar/foo", ["bar", "foo"], 0, True, id="R: 'bar/foo' matches ['bar', 'foo'] with overload_index=0"),
168+
pytest.param("bar/foo", ["bar", "foo"], 1, True, id="R: 'bar/foo' matches ['bar', 'foo'] with overload_index=1"),
169+
pytest.param("bar/foo[0]", ["bar", "foo"], 0, True, id="R: 'bar/foo[0]' matches ['bar', 'foo'] with overload_index=0"),
170+
pytest.param("bar/foo[1]", ["bar", "foo"], 0, False, id="R: 'bar/foo[1]' does not match ['bar', 'foo'] with overload_index=0"),
171+
],
172+
)
173+
def test_match_name_path_pattern_with_overload_idx(self, name_path_pattern, symbol_name_path_parts, symbol_overload_idx, expected):
174+
"""Tests matching for qualified names (e.g. 'module/class/func')."""
175+
matcher = NamePathMatcher(name_path_pattern, False)
176+
result = matcher.matches_components(symbol_name_path_parts, symbol_overload_idx)
177+
error_msg = self._create_assertion_error_message(name_path_pattern, symbol_name_path_parts, False, expected, result)
178+
assert result == expected, error_msg

0 commit comments

Comments
 (0)