Skip to content

Commit 1873874

Browse files
authored
Start splitting up dataset.py (#10039)
* Start splitting up `dataset.py`

  Currently, `dataset.py` is 10956 lines long. This makes doing any work with current LLMs basically impossible — with Claude's tokenizer, the file is 104K tokens, or >2.5x the size of the _per-minute_ rate limit for basic accounts. Most of xarray touches it in some way, so you generally want to give it the file for context.

  Even if you don't think "LLMs are the future, let's code with vibes!", the file is still really long and can be difficult to navigate (though OTOH it can be easy to just grep, to be fair...).

  So I would propose:

  - We start breaking it up, while also being cognizant that big changes can cause merge conflicts
  - Start with the low-hanging fruit
  - For example, this PR moves code outside the class (but that's quite limited)
  - Then move some of the code from the big methods into functions in other files, like `curve_fit`
  - Possibly (has tradeoffs; needs discussion) build some mixins so we can split up the class, if we want to have much smaller files
  - We can also think about other files: `dataarray.py` is 7.5K lines. The tests are also huge (`test_dataset` is 7.5K lines), but unlike with the library code, we can copy out & in chunks of tests when developing.

  (Note that I don't have any strong views on exactly what code should go in which file; I made a quick guess — very open to any suggestions; also easy to change later, particularly since this code doesn't change much so is less likely to cause conflicts)

* .
1 parent 54946eb commit 1873874

File tree

5 files changed

+166
-132
lines changed

5 files changed

+166
-132
lines changed

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]:
893893
return dict(zip(self.dims, key, strict=True))
894894

895895
def _getitem_coord(self, key: Any) -> Self:
896-
from xarray.core.dataset import _get_virtual_variable
896+
from xarray.core.dataset_utils import _get_virtual_variable
897897

898898
try:
899899
var = self._coords[key]

xarray/core/dataset.py

Lines changed: 4 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
from operator import methodcaller
2525
from os import PathLike
2626
from types import EllipsisType
27-
from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast, overload
27+
from typing import IO, TYPE_CHECKING, Any, Literal, cast, overload
2828

2929
import numpy as np
3030
from pandas.api.types import is_extension_array_dtype
3131

32+
from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer
33+
from xarray.core.dataset_variables import DataVariables
34+
3235
# remove once numpy 2.0 is the oldest supported version
3336
try:
3437
from numpy.exceptions import RankWarning
@@ -98,7 +101,6 @@
98101
T_ChunksFreq,
99102
T_DataArray,
100103
T_DataArrayOrSet,
101-
T_Dataset,
102104
ZarrWriteModes,
103105
)
104106
from xarray.core.utils import (
@@ -196,43 +198,6 @@
196198
]
197199

198200

199-
def _get_virtual_variable(
200-
variables, key: Hashable, dim_sizes: Mapping | None = None
201-
) -> tuple[Hashable, Hashable, Variable]:
202-
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
203-
objects (if possible)
204-
205-
"""
206-
from xarray.core.dataarray import DataArray
207-
208-
if dim_sizes is None:
209-
dim_sizes = {}
210-
211-
if key in dim_sizes:
212-
data = pd.Index(range(dim_sizes[key]), name=key)
213-
variable = IndexVariable((key,), data)
214-
return key, key, variable
215-
216-
if not isinstance(key, str):
217-
raise KeyError(key)
218-
219-
split_key = key.split(".", 1)
220-
if len(split_key) != 2:
221-
raise KeyError(key)
222-
223-
ref_name, var_name = split_key
224-
ref_var = variables[ref_name]
225-
226-
if _contains_datetime_like_objects(ref_var):
227-
ref_var = DataArray(ref_var)
228-
data = getattr(ref_var.dt, var_name).data
229-
else:
230-
data = getattr(ref_var, var_name).data
231-
virtual_var = Variable(ref_var.dims, data)
232-
233-
return ref_name, var_name, virtual_var
234-
235-
236201
def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
237202
"""
238203
Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
@@ -367,19 +332,6 @@ def _maybe_chunk(
367332
return var
368333

369334

370-
def as_dataset(obj: Any) -> Dataset:
371-
"""Cast the given object to a Dataset.
372-
373-
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
374-
object is only created if the provided object is not already one.
375-
"""
376-
if hasattr(obj, "to_dataset"):
377-
obj = obj.to_dataset()
378-
if not isinstance(obj, Dataset):
379-
obj = Dataset(obj)
380-
return obj
381-
382-
383335
def _get_func_args(func, param_names):
384336
"""Use `inspect.signature` to try accessing `func` args. Otherwise, ensure
385337
they are provided by user.
@@ -468,84 +420,6 @@ def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
468420
)
469421

470422

471-
class DataVariables(Mapping[Any, "DataArray"]):
472-
__slots__ = ("_dataset",)
473-
474-
def __init__(self, dataset: Dataset):
475-
self._dataset = dataset
476-
477-
def __iter__(self) -> Iterator[Hashable]:
478-
return (
479-
key
480-
for key in self._dataset._variables
481-
if key not in self._dataset._coord_names
482-
)
483-
484-
def __len__(self) -> int:
485-
length = len(self._dataset._variables) - len(self._dataset._coord_names)
486-
assert length >= 0, "something is wrong with Dataset._coord_names"
487-
return length
488-
489-
def __contains__(self, key: Hashable) -> bool:
490-
return key in self._dataset._variables and key not in self._dataset._coord_names
491-
492-
def __getitem__(self, key: Hashable) -> DataArray:
493-
if key not in self._dataset._coord_names:
494-
return self._dataset[key]
495-
raise KeyError(key)
496-
497-
def __repr__(self) -> str:
498-
return formatting.data_vars_repr(self)
499-
500-
@property
501-
def variables(self) -> Mapping[Hashable, Variable]:
502-
all_variables = self._dataset.variables
503-
return Frozen({k: all_variables[k] for k in self})
504-
505-
@property
506-
def dtypes(self) -> Frozen[Hashable, np.dtype]:
507-
"""Mapping from data variable names to dtypes.
508-
509-
Cannot be modified directly, but is updated when adding new variables.
510-
511-
See Also
512-
--------
513-
Dataset.dtype
514-
"""
515-
return self._dataset.dtypes
516-
517-
def _ipython_key_completions_(self):
518-
"""Provide method for the key-autocompletions in IPython."""
519-
return [
520-
key
521-
for key in self._dataset._ipython_key_completions_()
522-
if key not in self._dataset._coord_names
523-
]
524-
525-
526-
class _LocIndexer(Generic[T_Dataset]):
527-
__slots__ = ("dataset",)
528-
529-
def __init__(self, dataset: T_Dataset):
530-
self.dataset = dataset
531-
532-
def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
533-
if not utils.is_dict_like(key):
534-
raise TypeError("can only lookup dictionaries from Dataset.loc")
535-
return self.dataset.sel(key)
536-
537-
def __setitem__(self, key, value) -> None:
538-
if not utils.is_dict_like(key):
539-
raise TypeError(
540-
"can only set locations defined by dictionaries from Dataset.loc."
541-
f" Got: {key}"
542-
)
543-
544-
# set new values
545-
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
546-
self.dataset[dim_indexers] = value
547-
548-
549423
class Dataset(
550424
DataWithCoords,
551425
DatasetAggregations,

xarray/core/dataset_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
from collections.abc import Hashable, Mapping
5+
from typing import Any, Generic
6+
7+
import pandas as pd
8+
9+
from xarray.core import utils
10+
from xarray.core.common import _contains_datetime_like_objects
11+
from xarray.core.indexing import map_index_queries
12+
from xarray.core.types import T_Dataset
13+
from xarray.core.variable import IndexVariable, Variable
14+
15+
if typing.TYPE_CHECKING:
16+
from xarray.core.dataset import Dataset
17+
18+
19+
class _LocIndexer(Generic[T_Dataset]):
20+
__slots__ = ("dataset",)
21+
22+
def __init__(self, dataset: T_Dataset):
23+
self.dataset = dataset
24+
25+
def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
26+
if not utils.is_dict_like(key):
27+
raise TypeError("can only lookup dictionaries from Dataset.loc")
28+
return self.dataset.sel(key)
29+
30+
def __setitem__(self, key, value) -> None:
31+
if not utils.is_dict_like(key):
32+
raise TypeError(
33+
"can only set locations defined by dictionaries from Dataset.loc."
34+
f" Got: {key}"
35+
)
36+
37+
# set new values
38+
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
39+
self.dataset[dim_indexers] = value
40+
41+
42+
def as_dataset(obj: Any) -> Dataset:
43+
"""Cast the given object to a Dataset.
44+
45+
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
46+
object is only created if the provided object is not already one.
47+
"""
48+
from xarray.core.dataset import Dataset
49+
50+
if hasattr(obj, "to_dataset"):
51+
obj = obj.to_dataset()
52+
if not isinstance(obj, Dataset):
53+
obj = Dataset(obj)
54+
return obj
55+
56+
57+
def _get_virtual_variable(
58+
variables, key: Hashable, dim_sizes: Mapping | None = None
59+
) -> tuple[Hashable, Hashable, Variable]:
60+
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
61+
objects (if possible)
62+
63+
"""
64+
from xarray.core.dataarray import DataArray
65+
66+
if dim_sizes is None:
67+
dim_sizes = {}
68+
69+
if key in dim_sizes:
70+
data = pd.Index(range(dim_sizes[key]), name=key)
71+
variable = IndexVariable((key,), data)
72+
return key, key, variable
73+
74+
if not isinstance(key, str):
75+
raise KeyError(key)
76+
77+
split_key = key.split(".", 1)
78+
if len(split_key) != 2:
79+
raise KeyError(key)
80+
81+
ref_name, var_name = split_key
82+
ref_var = variables[ref_name]
83+
84+
if _contains_datetime_like_objects(ref_var):
85+
ref_var = DataArray(ref_var)
86+
data = getattr(ref_var.dt, var_name).data
87+
else:
88+
data = getattr(ref_var, var_name).data
89+
virtual_var = Variable(ref_var.dims, data)
90+
91+
return ref_name, var_name, virtual_var

xarray/core/dataset_variables.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import typing
2+
from collections.abc import Hashable, Iterator, Mapping
3+
from typing import Any
4+
5+
import numpy as np
6+
7+
from xarray.core import formatting
8+
from xarray.core.utils import Frozen
9+
from xarray.core.variable import Variable
10+
11+
if typing.TYPE_CHECKING:
12+
from xarray.core.dataarray import DataArray
13+
from xarray.core.dataset import Dataset
14+
15+
16+
class DataVariables(Mapping[Any, "DataArray"]):
17+
__slots__ = ("_dataset",)
18+
19+
def __init__(self, dataset: "Dataset"):
20+
self._dataset = dataset
21+
22+
def __iter__(self) -> Iterator[Hashable]:
23+
return (
24+
key
25+
for key in self._dataset._variables
26+
if key not in self._dataset._coord_names
27+
)
28+
29+
def __len__(self) -> int:
30+
length = len(self._dataset._variables) - len(self._dataset._coord_names)
31+
assert length >= 0, "something is wrong with Dataset._coord_names"
32+
return length
33+
34+
def __contains__(self, key: Hashable) -> bool:
35+
return key in self._dataset._variables and key not in self._dataset._coord_names
36+
37+
def __getitem__(self, key: Hashable) -> "DataArray":
38+
if key not in self._dataset._coord_names:
39+
return self._dataset[key]
40+
raise KeyError(key)
41+
42+
def __repr__(self) -> str:
43+
return formatting.data_vars_repr(self)
44+
45+
@property
46+
def variables(self) -> Mapping[Hashable, Variable]:
47+
all_variables = self._dataset.variables
48+
return Frozen({k: all_variables[k] for k in self})
49+
50+
@property
51+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
52+
"""Mapping from data variable names to dtypes.
53+
54+
Cannot be modified directly, but is updated when adding new variables.
55+
56+
See Also
57+
--------
58+
Dataset.dtype
59+
"""
60+
return self._dataset.dtypes
61+
62+
def _ipython_key_completions_(self):
63+
"""Provide method for the key-autocompletions in IPython."""
64+
return [
65+
key
66+
for key in self._dataset._ipython_key_completions_()
67+
if key not in self._dataset._coord_names
68+
]

xarray/core/datatree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from xarray.core.common import TreeAttrAccessMixin, get_chunksizes
2222
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
2323
from xarray.core.dataarray import DataArray
24-
from xarray.core.dataset import Dataset, DataVariables
24+
from xarray.core.dataset import Dataset
25+
from xarray.core.dataset_variables import DataVariables
2526
from xarray.core.datatree_mapping import (
2627
map_over_datasets,
2728
)

0 commit comments

Comments
 (0)