Skip to content

Commit 254f6c5

Browse files
authored
DataTree should not be "Generic" (#9445)
* DataTree should not be "Generic" DataTree isn't a Generic tree type. It's a specific tree type -- the nodes are DataTree objects. This was resulting in many cases where mypy insisting on explicit type annotations, e.g., `tree: DataTree = DataTree(...)`, which is unnecessary and annoying boilerplate. * Fix type error * type ignore
1 parent 12c690f commit 254f6c5

File tree

2 files changed

+64
-65
lines changed

2 files changed

+64
-65
lines changed

xarray/core/datatree.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Mapping,
1212
)
1313
from html import escape
14-
from typing import TYPE_CHECKING, Any, Generic, Literal, NoReturn, Union, overload
14+
from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload
1515

1616
from xarray.core import utils
1717
from xarray.core.alignment import align
@@ -37,7 +37,7 @@
3737
from xarray.core.indexes import Index, Indexes
3838
from xarray.core.merge import dataset_update_method
3939
from xarray.core.options import OPTIONS as XR_OPTS
40-
from xarray.core.treenode import NamedNode, NodePath, Tree
40+
from xarray.core.treenode import NamedNode, NodePath
4141
from xarray.core.utils import (
4242
Default,
4343
Frozen,
@@ -365,8 +365,7 @@ class DataTree(
365365
MappedDataWithCoords,
366366
DataTreeArithmeticMixin,
367367
TreeAttrAccessMixin,
368-
Generic[Tree],
369-
Mapping,
368+
Mapping[str, "DataArray | DataTree"],
370369
):
371370
"""
372371
A tree-like hierarchical collection of xarray objects.
@@ -701,8 +700,8 @@ def __contains__(self, key: object) -> bool:
701700
def __bool__(self) -> bool:
702701
return bool(self._data_variables) or bool(self._children)
703702

704-
def __iter__(self) -> Iterator[Hashable]:
705-
return itertools.chain(self._data_variables, self._children)
703+
def __iter__(self) -> Iterator[str]:
704+
return itertools.chain(self._data_variables, self._children) # type: ignore
706705

707706
def __array__(self, dtype=None, copy=None):
708707
raise TypeError(

0 commit comments

Comments
 (0)