Skip to content

Commit 4550a01

Browse files
Illviljanpre-commit-ci[bot]andersy005
authored
Add expand_dims (#8407)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
1 parent c93b31a commit 4550a01

File tree

6 files changed

+78
-20
lines changed

6 files changed

+78
-20
lines changed

xarray/core/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2596,7 +2596,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
25962596
"""
25972597
Use sparse-array as backend.
25982598
"""
2599-
from xarray.namedarray.utils import _default as _default_named
2599+
from xarray.namedarray._typing import _default as _default_named
26002600

26012601
if sparse_format is _default:
26022602
sparse_format = _default_named

xarray/namedarray/_array_api.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import numpy as np
88

99
from xarray.namedarray._typing import (
10+
Default,
1011
_arrayapi,
12+
_Axis,
13+
_default,
14+
_Dim,
1115
_DType,
1216
_ScalarType,
1317
_ShapeType,
@@ -144,3 +148,51 @@ def real(
144148
xp = _get_data_namespace(x)
145149
out = x._new(data=xp.real(x._data))
146150
return out
151+
152+
153+
# %% Manipulation functions
154+
def expand_dims(
155+
x: NamedArray[Any, _DType],
156+
/,
157+
*,
158+
dim: _Dim | Default = _default,
159+
axis: _Axis = 0,
160+
) -> NamedArray[Any, _DType]:
161+
"""
162+
Expands the shape of an array by inserting a new dimension of size one at the
163+
position specified by dims.
164+
165+
Parameters
166+
----------
167+
x :
168+
Array to expand.
169+
dim :
170+
Dimension name. New dimension will be stored in the axis position.
171+
axis :
172+
(Not recommended) Axis position (zero-based). Default is 0.
173+
174+
Returns
175+
-------
176+
out :
177+
An expanded output array having the same data type as x.
178+
179+
Examples
180+
--------
181+
>>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]]))
182+
>>> expand_dims(x)
183+
<xarray.NamedArray (dim_2: 1, x: 2, y: 2)>
184+
Array([[[1., 2.],
185+
[3., 4.]]], dtype=float64)
186+
>>> expand_dims(x, dim="z")
187+
<xarray.NamedArray (z: 1, x: 2, y: 2)>
188+
Array([[[1., 2.],
189+
[3., 4.]]], dtype=float64)
190+
"""
191+
xp = _get_data_namespace(x)
192+
dims = x.dims
193+
if dim is _default:
194+
dim = f"dim_{len(dims)}"
195+
d = list(dims)
196+
d.insert(axis, dim)
197+
out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
198+
return out

xarray/namedarray/_typing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
from collections.abc import Hashable, Iterable, Mapping, Sequence
4+
from enum import Enum
45
from types import ModuleType
56
from typing import (
67
Any,
78
Callable,
9+
Final,
810
Protocol,
911
SupportsIndex,
1012
TypeVar,
@@ -15,6 +17,14 @@
1517

1618
import numpy as np
1719

20+
21+
# Singleton type, as per https://github.com/python/typing/pull/240
22+
class Default(Enum):
23+
token: Final = 0
24+
25+
26+
_default = Default.token
27+
1828
# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
1929
_T = TypeVar("_T")
2030
_T_co = TypeVar("_T_co", covariant=True)
@@ -49,6 +59,10 @@ def dtype(self) -> _DType_co:
4959
_ShapeType = TypeVar("_ShapeType", bound=Any)
5060
_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)
5161

62+
_Axis = int
63+
_Axes = tuple[_Axis, ...]
64+
_AxisLike = Union[_Axis, _Axes]
65+
5266
_Chunks = tuple[_Shape, ...]
5367

5468
_Dim = Hashable

xarray/namedarray/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_arrayapi,
2626
_arrayfunction_or_api,
2727
_chunkedarray,
28+
_default,
2829
_dtype,
2930
_DType_co,
3031
_ScalarType_co,
@@ -33,13 +34,14 @@
3334
_SupportsImag,
3435
_SupportsReal,
3536
)
36-
from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array
37+
from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array
3738

3839
if TYPE_CHECKING:
3940
from numpy.typing import ArrayLike, NDArray
4041

4142
from xarray.core.types import Dims
4243
from xarray.namedarray._typing import (
44+
Default,
4345
_AttrsLike,
4446
_Chunks,
4547
_Dim,
@@ -52,7 +54,6 @@
5254
_ShapeType,
5355
duckarray,
5456
)
55-
from xarray.namedarray.utils import Default
5657

5758
try:
5859
from dask.typing import (

xarray/namedarray/utils.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,7 @@
22

33
import sys
44
from collections.abc import Hashable
5-
from enum import Enum
6-
from typing import (
7-
TYPE_CHECKING,
8-
Any,
9-
Final,
10-
)
5+
from typing import TYPE_CHECKING, Any
116

127
import numpy as np
138

@@ -31,14 +26,6 @@
3126
DaskCollection: Any = NDArray # type: ignore
3227

3328

34-
# Singleton type, as per https://github.com/python/typing/pull/240
35-
class Default(Enum):
36-
token: Final = 0
37-
38-
39-
_default = Default.token
40-
41-
4229
def module_available(module: str) -> bool:
4330
"""Checks whether a module is installed without importing it.
4431

xarray/tests/test_namedarray.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,27 @@
1010
import pytest
1111

1212
from xarray.core.indexing import ExplicitlyIndexed
13-
from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co
13+
from xarray.namedarray._typing import (
14+
_arrayfunction_or_api,
15+
_default,
16+
_DType_co,
17+
_ShapeType_co,
18+
)
1419
from xarray.namedarray.core import NamedArray, from_array
15-
from xarray.namedarray.utils import _default
1620

1721
if TYPE_CHECKING:
1822
from types import ModuleType
1923

2024
from numpy.typing import ArrayLike, DTypeLike, NDArray
2125

2226
from xarray.namedarray._typing import (
27+
Default,
2328
_AttrsLike,
2429
_DimsLike,
2530
_DType,
2631
_Shape,
2732
duckarray,
2833
)
29-
from xarray.namedarray.utils import Default
3034

3135

3236
class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):

0 commit comments

Comments
 (0)