Skip to content

Commit 121bff6

Browse files
committed
Add conditional support for UnionType in DOM queries for Python 3.10+
- Use try/except to conditionally import UnionType and get_args - Make UnionType overload signature conditional - Add runtime checks before using UnionType functionality - Maintain backward compatibility with Python <3.10 This change allows users on Python 3.10+ to query with Union types while preserving compatibility with earlier Python versions.
1 parent a0969a6 commit 121bff6

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/textual/dom.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@
2121
Type,
2222
TypeVar,
2323
cast,
24-
get_args,
2524
overload,
2625
)
27-
from types import UnionType
26+
27+
try:
28+
from types import UnionType
29+
from typing import get_args
30+
except ImportError:
31+
UnionType = None # Type will not exist in earlier versions
32+
get_args = None # Not needed for earlier versions
2833

2934
import rich.repr
3035
from rich.highlighter import ReprHighlighter
@@ -1368,8 +1373,10 @@ def query(self, selector: str | None = None) -> DOMQuery[Widget]: ...
13681373
@overload
13691374
def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ...
13701375

1371-
@overload
1372-
def query(self, selector: UnionType) -> DOMQuery[Widget]: ...
1376+
if UnionType is not None:
1377+
1378+
@overload
1379+
def query(self, selector: UnionType) -> DOMQuery[Widget]: ...
13731380

13741381
def query(
13751382
self, selector: str | type[QueryType] | UnionType | None = None
@@ -1390,7 +1397,7 @@ def query(
13901397

13911398
if isinstance(selector, str) or selector is None:
13921399
return DOMQuery[Widget](self, filter=selector)
1393-
elif isinstance(selector, UnionType):
1400+
elif UnionType is not None and isinstance(selector, UnionType):
13941401
# Get all types from the union, including nested unions
13951402
def get_all_types(union_type):
13961403
types = set()
@@ -1400,12 +1407,14 @@ def get_all_types(union_type):
14001407
else:
14011408
types.add(t)
14021409
return types
1403-
1410+
14041411
# Validate all types in the union are Widget subclasses
14051412
types_in_union = get_args(selector)
1406-
if not all(isinstance(t, type) and issubclass(t, Widget) for t in types_in_union):
1413+
if not all(
1414+
isinstance(t, type) and issubclass(t, Widget) for t in types_in_union
1415+
):
14071416
raise TypeError("All types in Union must be Widget subclasses")
1408-
1417+
14091418
# Convert Union type to comma-separated string of class names
14101419
type_names = [t.__name__ for t in types_in_union]
14111420
selector_str = ", ".join(type_names)

0 commit comments

Comments
 (0)