Skip to content

Commit 7dcd1d3

Browse files
authored
Merge pull request #5578 from zarch/uniontype_on_query_method
feat: Add UnionType support to query method
2 parents 64cba2f + 631199e commit 7dcd1d3

File tree

3 files changed

+121
-5
lines changed

3 files changed

+121
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
6868
### Added
6969

7070
- Added `pointer_x`, `pointer_y`, `pointer_screen_x`, and `pointer_screen_y` attributes to mouse events https://github.com/Textualize/textual/pull/5556
71+
- DOMNode.query now accepts UnionType for selector, e.g. `self.query(Input | Select )` https://github.com/Textualize/textual/pull/5578
7172

7273
### Changed
7374

src/textual/dom.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
overload,
2525
)
2626

27+
try:
28+
from types import UnionType
29+
from typing import get_args
30+
except ImportError:
31+
UnionType = None # Type will not exist in earlier versions
32+
get_args = None # Not needed for earlier versions
33+
2734
import rich.repr
2835
from rich.highlighter import ReprHighlighter
2936
from rich.style import Style
@@ -95,9 +102,9 @@ def check_identifiers(description: str, *names: str) -> None:
95102
description: Description of where identifier is used for error message.
96103
*names: Identifiers to check.
97104
"""
98-
match = _re_identifier.fullmatch
105+
fullmatch = _re_identifier.fullmatch
99106
for name in names:
100-
if match(name) is None:
107+
if fullmatch(name) is None:
101108
raise BadIdentifier(
102109
f"{name!r} is an invalid {description}; "
103110
"identifiers must contain only letters, numbers, underscores, or hyphens, and must not begin with a number."
@@ -1366,22 +1373,52 @@ def query(self, selector: str | None = None) -> DOMQuery[Widget]: ...
13661373
@overload
13671374
def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ...
13681375

1376+
if UnionType is not None:
1377+
1378+
@overload
1379+
def query(self, selector: UnionType) -> DOMQuery[Widget]: ...
1380+
13691381
def query(
1370-
self, selector: str | type[QueryType] | None = None
1382+
self, selector: str | type[QueryType] | UnionType | None = None
13711383
) -> DOMQuery[Widget] | DOMQuery[QueryType]:
1372-
"""Query the DOM for children that match a selector or widget type.
1384+
"""Query the DOM for children that match a selector or widget type, or a union of widget types.
13731385
13741386
Args:
1375-
selector: A CSS selector, widget type, or `None` for all nodes.
1387+
selector: A CSS selector, widget type, a union of widget types, or `None` for all nodes.
13761388
13771389
Returns:
13781390
A query object.
1391+
1392+
Raises:
1393+
TypeError: If any type in a Union is not a Widget subclass.
13791394
"""
13801395
from textual.css.query import DOMQuery, QueryType
13811396
from textual.widget import Widget
13821397

13831398
if isinstance(selector, str) or selector is None:
13841399
return DOMQuery[Widget](self, filter=selector)
1400+
elif UnionType is not None and isinstance(selector, UnionType):
1401+
# Get all types from the union, including nested unions
1402+
def get_all_types(union_type):
1403+
types = set()
1404+
for t in get_args(union_type):
1405+
if isinstance(t, UnionType):
1406+
types.update(get_all_types(t))
1407+
else:
1408+
types.add(t)
1409+
return types
1410+
1411+
# Validate all types in the union are Widget subclasses
1412+
types_in_union = get_args(selector)
1413+
if not all(
1414+
isinstance(t, type) and issubclass(t, Widget) for t in types_in_union
1415+
):
1416+
raise TypeError("All types in Union must be Widget subclasses")
1417+
1418+
# Convert Union type to comma-separated string of class names
1419+
type_names = [t.__name__ for t in types_in_union]
1420+
selector_str = ", ".join(type_names)
1421+
return DOMQuery[Widget](self, filter=selector_str)
13851422
else:
13861423
return DOMQuery[QueryType](self, filter=selector.__name__)
13871424

tests/test_dom.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import pytest
22

3+
from textual.app import App
34
from textual.css.errors import StyleValueError
45
from textual.dom import BadIdentifier, DOMNode
6+
from textual.widget import Widget
7+
from textual.widgets import Input, Select, Static
58

69

710
def test_display_default():
@@ -280,3 +283,78 @@ def test_id_validation(identifier: str):
280283
"""Regression tests for https://github.com/Textualize/textual/issues/3954."""
281284
with pytest.raises(BadIdentifier):
282285
DOMNode(id=identifier)
286+
287+
288+
class SimpleApp(App):
289+
def compose(self):
290+
yield Input(id="input1")
291+
yield Select([], id="select1")
292+
yield Static("Hello", id="static1")
293+
yield Input(id="input2")
294+
295+
296+
async def test_query_union_type():
297+
# Test with a UnionType
298+
simple_app = SimpleApp()
299+
async with simple_app.run_test():
300+
results = simple_app.query(Input | Select)
301+
assert len(results) == 3
302+
assert {w.id for w in results} == {"input1", "select1", "input2"}
303+
304+
# Test with a single type
305+
results2 = simple_app.query(Input)
306+
assert len(results2) == 2
307+
assert {w.id for w in results2} == {"input1", "input2"}
308+
309+
# Test with string selector
310+
results3 = simple_app.query("#input1")
311+
assert len(results3) == 1
312+
assert results3[0].id == "input1"
313+
314+
315+
async def test_query_nested_unions():
316+
"""Test handling of nested unions."""
317+
318+
simple_app = SimpleApp()
319+
async with simple_app.run_test():
320+
# Create nested union types
321+
InputOrSelect = Input | Select
322+
InputSelectOrStatic = InputOrSelect | Static
323+
324+
# Test nested union query
325+
results = simple_app.query(InputSelectOrStatic)
326+
327+
# Verify that we find all our explicitly defined widgets
328+
widget_ids = {w.id for w in results if w.id is not None}
329+
expected_ids = {"input1", "select1", "static1", "input2"}
330+
assert expected_ids.issubset(widget_ids), "Not all expected widgets were found"
331+
332+
# Verify we get the right types of widgets
333+
assert all(
334+
isinstance(w, (Input, Select, Static)) for w in results
335+
), "Found unexpected widget types"
336+
337+
# Verify each expected widget appears exactly once
338+
for expected_id in expected_ids:
339+
matching_widgets = [w for w in results if w.id == expected_id]
340+
assert (
341+
len(matching_widgets) == 1
342+
), f"Widget with id {expected_id} should appear exactly once"
343+
344+
345+
async def test_query_empty_union():
346+
"""Test querying with empty or invalid unions."""
347+
348+
class AnotherWidget(Widget):
349+
pass
350+
351+
simple_app = SimpleApp()
352+
async with simple_app.run_test():
353+
354+
# Test with a type that exists but has no matches
355+
results = simple_app.query(AnotherWidget)
356+
assert len(results) == 0
357+
358+
# Test with widget union that has no matches
359+
results = simple_app.query(AnotherWidget | AnotherWidget)
360+
assert len(results) == 0

0 commit comments

Comments
 (0)