Skip to content

Commit 268971e

Browse files
authored
Merge pull request #5365 from Textualize/cp-first-word-boost
boost for first letter matches
2 parents 44d250a + ea2a731 commit 268971e

File tree

5 files changed

+165
-64
lines changed

5 files changed

+165
-64
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
1818
- Changed default quit key to `ctrl+q` https://github.com/Textualize/textual/pull/5352
1919
- Changed delete line binding on TextArea to use `ctrl+shift+x` https://github.com/Textualize/textual/pull/5352
2020
- The command palette will now select the top item automatically https://github.com/Textualize/textual/pull/5361
21+
- Implemented a better matching algorithm for the command palette https://github.com/Textualize/textual/pull/5365
2122

2223
### Fixed
2324

src/textual/containers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __init__(
267267
stretch_height: bool = True,
268268
regular: bool = False,
269269
) -> None:
270-
"""Initialize a Widget.
270+
"""
271271
272272
Args:
273273
*children: Child widgets.

src/textual/fuzzy.py

Lines changed: 148 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,151 @@
77

88
from __future__ import annotations
99

10-
from re import IGNORECASE, compile, escape
10+
from operator import itemgetter
11+
from re import IGNORECASE, escape, finditer, search
12+
from typing import Iterable, NamedTuple
1113

1214
import rich.repr
1315
from rich.style import Style
1416
from rich.text import Text
1517

16-
from textual.cache import LRUCache
18+
19+
class _Search(NamedTuple):
20+
"""Internal structure to keep track of a recursive search."""
21+
22+
candidate_offset: int = 0
23+
query_offset: int = 0
24+
offsets: tuple[int, ...] = ()
25+
26+
def branch(self, offset: int) -> tuple[_Search, _Search]:
27+
"""Branch this search when an offset is found.
28+
29+
Args:
30+
offset: Offset of a matching letter in the query.
31+
32+
Returns:
33+
A pair of search objects.
34+
"""
35+
_, query_offset, offsets = self
36+
return (
37+
_Search(offset + 1, query_offset + 1, offsets + (offset,)),
38+
_Search(offset + 1, query_offset, offsets),
39+
)
40+
41+
@property
42+
def groups(self) -> int:
43+
"""Number of groups in offsets."""
44+
groups = 1
45+
last_offset = self.offsets[0]
46+
for offset in self.offsets[1:]:
47+
if offset != last_offset + 1:
48+
groups += 1
49+
last_offset = offset
50+
return groups
51+
52+
53+
class FuzzySearch:
    """Performs a fuzzy search.

    Unlike a regex solution, this will find all possible matches (up to an
    internal loop limit) and return the highest-scoring one.
    """

    def __init__(self, case_sensitive: bool = False, cache_size: int = 4096) -> None:
        """Initialize fuzzy search.

        Args:
            case_sensitive: Is the match case sensitive?
            cache_size: Maximum number of results kept in the cache. When the
                limit is reached the oldest entry is discarded. A value of 0
                (or less) disables caching.
        """
        self.cache: dict[tuple[str, str, bool], tuple[float, tuple[int, ...]]] = {}
        self.case_sensitive = case_sensitive
        self.cache_size = cache_size

    def match(self, query: str, candidate: str) -> tuple[float, tuple[int, ...]]:
        """Match against a query.

        Args:
            query: The fuzzy query.
            candidate: A candidate to check.

        Returns:
            A pair of (score, tuple of offsets). `(0, ())` for no result,
            which includes the case of an empty query.
        """
        if not query:
            # An empty pattern would pass the regex pre-check below and then
            # fail when the matcher indexes the first query character.
            return (0.0, ())

        cache_key = (query, candidate, self.case_sensitive)
        cached = self.cache.get(cache_key)
        if cached is not None:
            return cached

        query_regex = ".*?".join(f"({escape(character)})" for character in query)
        result: tuple[float, tuple[int, ...]]
        if search(
            query_regex, candidate, flags=0 if self.case_sensitive else IGNORECASE
        ):
            result = max(
                self._match(query, candidate), key=itemgetter(0), default=(0.0, ())
            )
        else:
            # Bail out early if there is no possibility of a match
            result = (0.0, ())

        if self.cache_size > 0:
            if len(self.cache) >= self.cache_size:
                # dicts preserve insertion order, so the first key is the oldest.
                del self.cache[next(iter(self.cache))]
            self.cache[cache_key] = result
        return result

    def _match(
        self, query: str, candidate: str
    ) -> Iterable[tuple[float, tuple[int, ...]]]:
        """Generator to do the matching.

        Args:
            query: Query to match.
            candidate: Candidate to check against.

        Yields:
            Pairs of score and tuple of offsets. Nothing is yielded for an
            empty query.
        """
        if not query:
            return

        if not self.case_sensitive:
            query = query.lower()
            candidate = candidate.lower()

        # We need this to give a bonus to first letters.
        first_letters = {match.start() for match in finditer(r"\w+", candidate)}

        def score(found: _Search) -> float:
            """Score a search.

            Args:
                found: Search object with at least one recorded offset.

            Returns:
                Score.
            """
            # This is a heuristic, and can be tweaked for better results
            offsets = found.offsets
            # Boost first letter matches
            base: float = sum(
                (2.0 if offset in first_letters else 1.0) for offset in offsets
            )
            # Boost to favor fewer groups
            offset_count = len(offsets)
            normalized_groups = (offset_count - (found.groups - 1)) / offset_count
            return base * (1 + (normalized_groups**2))

        stack: list[_Search] = [_Search()]
        push = stack.append
        pop = stack.pop
        query_size = len(query)
        find = candidate.find
        # Limit the number of loops out of an abundance of caution.
        # This would be hard to reach without contrived data.
        remaining_loops = 200

        while stack and (remaining_loops := remaining_loops - 1):
            current = pop()
            offset = find(query[current.query_offset], current.candidate_offset)
            if offset == -1:
                # This branch can not be completed; drop it.
                continue
            advance_branch, skip_branch = current.branch(offset)
            if advance_branch.query_offset == query_size:
                # Every query character has been placed: a complete match.
                yield score(advance_branch), advance_branch.offsets
            else:
                push(advance_branch)
            # Also explore matching the same query character further along.
            push(skip_branch)
17155

18156

19157
@rich.repr.auto
@@ -36,11 +174,8 @@ def __init__(
36174
"""
37175
self._query = query
38176
self._match_style = Style(reverse=True) if match_style is None else match_style
39-
self._query_regex = compile(
40-
".*?".join(f"({escape(character)})" for character in query),
41-
flags=0 if case_sensitive else IGNORECASE,
42-
)
43-
self._cache: LRUCache[str, float] = LRUCache(1024 * 4)
177+
self._case_sensitive = case_sensitive
178+
self.fuzzy_search = FuzzySearch()
44179

45180
@property
46181
def query(self) -> str:
@@ -52,15 +187,10 @@ def match_style(self) -> Style:
52187
"""The style that will be used to highlight hits in the matched text."""
53188
return self._match_style
54189

55-
@property
56-
def query_pattern(self) -> str:
57-
"""The regular expression pattern built from the query."""
58-
return self._query_regex.pattern
59-
60190
@property
61191
def case_sensitive(self) -> bool:
62192
"""Is this matcher case sensitive?"""
63-
return not bool(self._query_regex.flags & IGNORECASE)
193+
return self._case_sensitive
64194

65195
def match(self, candidate: str) -> float:
66196
"""Match the candidate against the query.
@@ -71,27 +201,7 @@ def match(self, candidate: str) -> float:
71201
Returns:
72202
Strength of the match from 0 to 1.
73203
"""
74-
cached = self._cache.get(candidate)
75-
if cached is not None:
76-
return cached
77-
match = self._query_regex.search(candidate)
78-
if match is None:
79-
score = 0.0
80-
else:
81-
assert match.lastindex is not None
82-
offsets = [
83-
match.span(group_no)[0] for group_no in range(1, match.lastindex + 1)
84-
]
85-
group_count = 0
86-
last_offset = -2
87-
for offset in offsets:
88-
if offset > last_offset + 1:
89-
group_count += 1
90-
last_offset = offset
91-
92-
score = 1.0 - ((group_count - 1) / len(candidate))
93-
self._cache[candidate] = score
94-
return score
204+
return self.fuzzy_search.match(self.query, candidate)[0]
95205

96206
def highlight(self, candidate: str) -> Text:
97207
"""Highlight the candidate with the fuzzy match.
@@ -102,20 +212,11 @@ def highlight(self, candidate: str) -> Text:
102212
Returns:
103213
A [rich.text.Text][`Text`] object with highlighted matches.
104214
"""
105-
match = self._query_regex.search(candidate)
106215
text = Text.from_markup(candidate)
107-
if match is None:
216+
score, offsets = self.fuzzy_search.match(self.query, candidate)
217+
if not score:
108218
return text
109-
assert match.lastindex is not None
110-
if self._query in text.plain:
111-
# Favor complete matches
112-
offset = text.plain.index(self._query)
113-
text.stylize(self._match_style, offset, offset + len(self._query))
114-
else:
115-
offsets = [
116-
match.span(group_no)[0] for group_no in range(1, match.lastindex + 1)
117-
]
118-
for offset in offsets:
219+
for offset in offsets:
220+
if not candidate[offset].isspace():
119221
text.stylize(self._match_style, offset, offset + 1)
120-
121222
return text

tests/snapshot_tests/test_snapshots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,7 @@ def test_example_color_command(snap_compare):
15101510
"""Test the color_command example."""
15111511
assert snap_compare(
15121512
EXAMPLES_DIR / "color_command.py",
1513-
press=[App.COMMAND_PALETTE_BINDING, "r", "e", "d", "down", "enter"],
1513+
press=[App.COMMAND_PALETTE_BINDING, "r", "e", "d", "enter"],
15141514
)
15151515

15161516

tests/test_fuzzy.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,24 @@
44
from textual.fuzzy import Matcher
55

66

7-
def test_match():
8-
matcher = Matcher("foo.bar")
7+
def test_no_match():
8+
"""Check non matching score of zero."""
9+
matcher = Matcher("x")
10+
assert matcher.match("foo") == 0
11+
912

10-
# No match
11-
assert matcher.match("egg") == 0
12-
assert matcher.match("") == 0
13+
def test_match_single_group():
14+
"""Check that single groups rang higher."""
15+
matcher = Matcher("abc")
16+
assert matcher.match("foo abc bar") > matcher.match("fooa barc")
1317

14-
# Perfect match
15-
assert matcher.match("foo.bar") == 1.0
16-
# Perfect match (with superfluous characters)
17-
assert matcher.match("foo.bar sdf") == 1.0
18-
assert matcher.match("xz foo.bar sdf") == 1.0
1918

20-
# Partial matches
21-
# 2 Groups
22-
assert matcher.match("foo egg.bar") == 1.0 - 1 / 11
19+
def test_boosted_matches():
20+
"""Check first word matchers rank higher."""
21+
matcher = Matcher("ss")
2322

24-
# 3 Groups
25-
assert matcher.match("foo .ba egg r") == 1.0 - 2 / 13
23+
# First word matchers should score higher
24+
assert matcher.match("Save Screenshot") > matcher.match("Show Keys abcde")
2625

2726

2827
def test_highlight():

0 commit comments

Comments
 (0)