1919 from .agent import SerenaAgent
2020
2121log = logging .getLogger (__name__ )
22+ NAME_PATH_SEP = "/"
2223
2324
2425@dataclass
@@ -113,45 +114,65 @@ def is_neighbouring_definition_separated_by_empty_line(self) -> bool:
113114 """
114115
115116
116- class LanguageServerSymbol (Symbol , ToStringMixin ):
117- _NAME_PATH_SEP = "/"
118-
119- @staticmethod
120- def match_name_path (
121- name_path : str ,
122- symbol_name_path_parts : list [str ],
123- substring_matching : bool ,
124- ) -> bool :
117+ class NamePathMatcher (ToStringMixin ):
118+ def __init__ (self , name_path_expr : str , substring_matching : bool ) -> None :
125119 """
126- Checks if a given `name_path` matches a symbol's qualified name parts.
127- See docstring of `Symbol.find` for more details.
120+ :param name_path_expr: the name path expression to match against
121+ :param substring_matching: whether to use substring matching for the last segment
128122 """
129- assert name_path , "name_path must not be empty"
130- assert symbol_name_path_parts , "symbol_name_path_parts must not be empty"
131- name_path_sep = LanguageServerSymbol ._NAME_PATH_SEP
123+ assert name_path_expr , "name_path must not be empty"
124+ self ._expr = name_path_expr
125+ self ._substring_matching = substring_matching
126+ self ._is_absolute_pattern = name_path_expr .startswith (NAME_PATH_SEP )
127+ self ._pattern_parts = name_path_expr .lstrip (NAME_PATH_SEP ).rstrip (NAME_PATH_SEP ).split (NAME_PATH_SEP )
128+
129+ # extract overload index "[idx]" if present at end of last part
130+ self ._overload_idx : int | None = None
131+ last_part = self ._pattern_parts [- 1 ]
132+ if last_part .endswith ("]" ) and "[" in last_part :
133+ bracket_idx = last_part .rfind ("[" )
134+ index_part = last_part [bracket_idx + 1 : - 1 ]
135+ if index_part .isdigit ():
136+ self ._pattern_parts [- 1 ] = last_part [:bracket_idx ]
137+ self ._overload_idx = int (index_part )
132138
133- is_absolute_pattern = name_path .startswith (name_path_sep )
134- pattern_parts = name_path .lstrip (name_path_sep ).rstrip (name_path_sep ).split (name_path_sep )
139+ def _tostring_includes (self ) -> list [str ]:
140+ return ["_expr" ]
141+
142+ def matches_ls_symbol (self , symbol : "LanguageServerSymbol" ) -> bool :
143+ return self .matches_components (symbol .get_name_path_parts (), symbol .overload_idx )
135144
145+ def matches_components (self , symbol_name_path_parts : list [str ], overload_idx : int | None ) -> bool :
136146 # filtering based on ancestors
137- if len (pattern_parts ) > len (symbol_name_path_parts ):
147+ if len (self . _pattern_parts ) > len (symbol_name_path_parts ):
138148 # can't possibly match if pattern has more parts than symbol
139149 return False
140- if is_absolute_pattern and len (pattern_parts ) != len (symbol_name_path_parts ):
150+ if self . _is_absolute_pattern and len (self . _pattern_parts ) != len (symbol_name_path_parts ):
141151 # for absolute patterns, the number of parts must match exactly
142152 return False
143- if symbol_name_path_parts [- len (pattern_parts ) : - 1 ] != pattern_parts [:- 1 ]:
153+ if symbol_name_path_parts [- len (self . _pattern_parts ) : - 1 ] != self . _pattern_parts [:- 1 ]:
144154 # ancestors must match
145155 return False
146156
147157 # matching the last part of the symbol name
148- name_to_match = pattern_parts [- 1 ]
158+ name_to_match = self . _pattern_parts [- 1 ]
149159 symbol_name = symbol_name_path_parts [- 1 ]
150- if substring_matching :
151- return name_to_match in symbol_name
160+ if self ._substring_matching :
161+ if name_to_match not in symbol_name :
162+ return False
152163 else :
153- return name_to_match == symbol_name
164+ if name_to_match != symbol_name :
165+ return False
154166
167+ # check for matching overload index
168+ if self ._overload_idx is not None :
169+ if overload_idx != self ._overload_idx :
170+ return False
171+
172+ return True
173+
174+
175+ class LanguageServerSymbol (Symbol , ToStringMixin ):
155176 def __init__ (self , symbol_root_from_ls : UnifiedSymbolInformation ) -> None :
156177 self .symbol_root = symbol_root_from_ls
157178
@@ -173,6 +194,10 @@ def kind(self) -> str:
173194 def symbol_kind (self ) -> SymbolKind :
174195 return self .symbol_root ["kind" ]
175196
197+ @property
198+ def overload_idx (self ) -> int | None :
199+ return self .symbol_root .get ("overload_idx" )
200+
176201 def is_neighbouring_definition_separated_by_empty_line (self ) -> bool :
177202 return self .symbol_kind in (SymbolKind .Function , SymbolKind .Method , SymbolKind .Class , SymbolKind .Interface , SymbolKind .Struct )
178203
@@ -256,9 +281,13 @@ def body(self) -> str | None:
256281
257282 def get_name_path (self ) -> str :
258283 """
259- Get the name path of the symbol (e.g. "class/method/inner_function").
284+ Get the name path of the symbol, e.g. "class/method/inner_function" or
285+ "class/method[1]" (overloaded method with identifying index).
260286 """
261- return self ._NAME_PATH_SEP .join (self .get_name_path_parts ())
287+ name_path = NAME_PATH_SEP .join (self .get_name_path_parts ())
288+ if "overload_idx" in self .symbol_root :
289+ name_path += f"[{ self .symbol_root ['overload_idx' ]} ]"
290+ return name_path
262291
263292 def get_name_path_parts (self ) -> list [str ]:
264293 """
@@ -321,25 +350,22 @@ def find(
321350 For example, passing `/class` will match only against top-level symbols like `class` but not against `nested_class/class`.
322351 Passing `/class/method` will match against `class/method` but not `nested_class/class/method` or `method`.
323352
324- :param name_path: the name path to match against
353+ :param name_path: the name path expression to match against
325354 :param substring_matching: whether to use substring matching (as opposed to exact matching)
326355 of the last segment of `name_path` against the symbol name.
327356 :param include_kinds: an optional sequence of ints representing the LSP symbol kind.
328357 If provided, only symbols of the given kinds will be included in the result.
329358 :param exclude_kinds: If provided, symbols of the given kinds will be excluded from the result.
330359 """
331360 result = []
361+ name_path_matcher = NamePathMatcher (name_path , substring_matching )
332362
333363 def should_include (s : "LanguageServerSymbol" ) -> bool :
334364 if include_kinds is not None and s .symbol_kind not in include_kinds :
335365 return False
336366 if exclude_kinds is not None and s .symbol_kind in exclude_kinds :
337367 return False
338- return LanguageServerSymbol .match_name_path (
339- name_path = name_path ,
340- symbol_name_path_parts = s .get_name_path_parts (),
341- substring_matching = substring_matching ,
342- )
368+ return name_path_matcher .matches_ls_symbol (s )
343369
344370 def traverse (s : "LanguageServerSymbol" ) -> None :
345371 if should_include (s ):
0 commit comments