Skip to content

Commit f7d2be1

Browse files
authored
Merge pull request #5679 from Textualize/revert-5578-uniontype_on_query_method
Revert "feat: Add UnionType support to query method"
2 parents 048040f + c39fb9e commit f7d2be1

File tree

2 files changed

+5
-120
lines changed

2 files changed

+5
-120
lines changed

src/textual/dom.py

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,6 @@
2424
overload,
2525
)
2626

27-
try:
28-
from types import UnionType
29-
from typing import get_args
30-
except ImportError:
31-
UnionType = None # Type will not exist in earlier versions
32-
get_args = None # Not needed for earlier versions
33-
3427
import rich.repr
3528
from rich.highlighter import ReprHighlighter
3629
from rich.style import Style
@@ -102,9 +95,9 @@ def check_identifiers(description: str, *names: str) -> None:
10295
description: Description of where identifier is used for error message.
10396
*names: Identifiers to check.
10497
"""
105-
fullmatch = _re_identifier.fullmatch
98+
match = _re_identifier.fullmatch
10699
for name in names:
107-
if fullmatch(name) is None:
100+
if match(name) is None:
108101
raise BadIdentifier(
109102
f"{name!r} is an invalid {description}; "
110103
"identifiers must contain only letters, numbers, underscores, or hyphens, and must not begin with a number."
@@ -1373,52 +1366,22 @@ def query(self, selector: str | None = None) -> DOMQuery[Widget]: ...
13731366
@overload
13741367
def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ...
13751368

1376-
if UnionType is not None:
1377-
1378-
@overload
1379-
def query(self, selector: UnionType) -> DOMQuery[Widget]: ...
1380-
13811369
def query(
1382-
self, selector: str | type[QueryType] | UnionType | None = None
1370+
self, selector: str | type[QueryType] | None = None
13831371
) -> DOMQuery[Widget] | DOMQuery[QueryType]:
1384-
"""Query the DOM for children that match a selector or widget type, or a union of widget types.
1372+
"""Query the DOM for children that match a selector or widget type.
13851373
13861374
Args:
1387-
selector: A CSS selector, widget type, a union of widget types, or `None` for all nodes.
1375+
selector: A CSS selector, widget type, or `None` for all nodes.
13881376
13891377
Returns:
13901378
A query object.
1391-
1392-
Raises:
1393-
TypeError: If any type in a Union is not a Widget subclass.
13941379
"""
13951380
from textual.css.query import DOMQuery, QueryType
13961381
from textual.widget import Widget
13971382

13981383
if isinstance(selector, str) or selector is None:
13991384
return DOMQuery[Widget](self, filter=selector)
1400-
elif UnionType is not None and isinstance(selector, UnionType):
1401-
# Get all types from the union, including nested unions
1402-
def get_all_types(union_type):
1403-
types = set()
1404-
for t in get_args(union_type):
1405-
if isinstance(t, UnionType):
1406-
types.update(get_all_types(t))
1407-
else:
1408-
types.add(t)
1409-
return types
1410-
1411-
# Validate all types in the union are Widget subclasses
1412-
types_in_union = get_args(selector)
1413-
if not all(
1414-
isinstance(t, type) and issubclass(t, Widget) for t in types_in_union
1415-
):
1416-
raise TypeError("All types in Union must be Widget subclasses")
1417-
1418-
# Convert Union type to comma-separated string of class names
1419-
type_names = [t.__name__ for t in types_in_union]
1420-
selector_str = ", ".join(type_names)
1421-
return DOMQuery[Widget](self, filter=selector_str)
14221385
else:
14231386
return DOMQuery[QueryType](self, filter=selector.__name__)
14241387

tests/test_dom.py

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import pytest
22

3-
from textual.app import App
43
from textual.css.errors import StyleValueError
54
from textual.dom import BadIdentifier, DOMNode
6-
from textual.widget import Widget
7-
from textual.widgets import Input, Select, Static
85

96

107
def test_display_default():
@@ -283,78 +280,3 @@ def test_id_validation(identifier: str):
283280
"""Regression tests for https://github.com/Textualize/textual/issues/3954."""
284281
with pytest.raises(BadIdentifier):
285282
DOMNode(id=identifier)
286-
287-
288-
class SimpleApp(App):
289-
def compose(self):
290-
yield Input(id="input1")
291-
yield Select([], id="select1")
292-
yield Static("Hello", id="static1")
293-
yield Input(id="input2")
294-
295-
296-
async def test_query_union_type():
297-
# Test with a UnionType
298-
simple_app = SimpleApp()
299-
async with simple_app.run_test():
300-
results = simple_app.query(Input | Select)
301-
assert len(results) == 3
302-
assert {w.id for w in results} == {"input1", "select1", "input2"}
303-
304-
# Test with a single type
305-
results2 = simple_app.query(Input)
306-
assert len(results2) == 2
307-
assert {w.id for w in results2} == {"input1", "input2"}
308-
309-
# Test with string selector
310-
results3 = simple_app.query("#input1")
311-
assert len(results3) == 1
312-
assert results3[0].id == "input1"
313-
314-
315-
async def test_query_nested_unions():
316-
"""Test handling of nested unions."""
317-
318-
simple_app = SimpleApp()
319-
async with simple_app.run_test():
320-
# Create nested union types
321-
InputOrSelect = Input | Select
322-
InputSelectOrStatic = InputOrSelect | Static
323-
324-
# Test nested union query
325-
results = simple_app.query(InputSelectOrStatic)
326-
327-
# Verify that we find all our explicitly defined widgets
328-
widget_ids = {w.id for w in results if w.id is not None}
329-
expected_ids = {"input1", "select1", "static1", "input2"}
330-
assert expected_ids.issubset(widget_ids), "Not all expected widgets were found"
331-
332-
# Verify we get the right types of widgets
333-
assert all(
334-
isinstance(w, (Input, Select, Static)) for w in results
335-
), "Found unexpected widget types"
336-
337-
# Verify each expected widget appears exactly once
338-
for expected_id in expected_ids:
339-
matching_widgets = [w for w in results if w.id == expected_id]
340-
assert (
341-
len(matching_widgets) == 1
342-
), f"Widget with id {expected_id} should appear exactly once"
343-
344-
345-
async def test_query_empty_union():
346-
"""Test querying with empty or invalid unions."""
347-
348-
class AnotherWidget(Widget):
349-
pass
350-
351-
simple_app = SimpleApp()
352-
async with simple_app.run_test():
353-
354-
# Test with a type that exists but has no matches
355-
results = simple_app.query(AnotherWidget)
356-
assert len(results) == 0
357-
358-
# Test with widget union that has no matches
359-
results = simple_app.query(AnotherWidget | AnotherWidget)
360-
assert len(results) == 0

0 commit comments

Comments
 (0)