Skip to content

Commit 339ed93

Browse files
scott-hubertydcherianheadtr1ck
authored
Soft import (#9561)
* ENH, TST: aux func for importing optional deps * ENH: use our new helper func for importing optional deps * FIX: use aux func for a few more cftime imports * FIX: remove cruft.... * FIX: Make it play well with mypy Per the proposal at #9561 (comment) This pairs any use of (a now simplified) `attempt_import` with a direct import of the same module, guarded by an `if TYPE_CHECKING` block. * FIX, TST: match error * Update xarray/tests/test_utils.py Co-authored-by: Michael Niklas <mick.niklas@gmail.com> * DOC: add examples section to docstring * refactor: use try-except clause and return original error to user - Also change raise ImportError to raise RuntimeError, since we are catching both ImportError and ModuleNotFoundError * TST: test import of submodules * FIX: Incorporate @headtr1ck suggetsions From #9561 (comment) #9561 (comment) --------- Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com> Co-authored-by: Michael Niklas <mick.niklas@gmail.com>
1 parent 700191b commit 339ed93

File tree

10 files changed

+183
-82
lines changed

10 files changed

+183
-82
lines changed

xarray/backends/common.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
from xarray.core import indexing
1515
from xarray.core.datatree import DataTree
1616
from xarray.core.types import ReadBuffer
17-
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
17+
from xarray.core.utils import (
18+
FrozenDict,
19+
NdimSizeLenMixin,
20+
attempt_import,
21+
is_remote_uri,
22+
)
1823
from xarray.namedarray.parallelcompat import get_chunked_array_type
1924
from xarray.namedarray.pycompat import is_chunked_array
2025

@@ -132,14 +137,12 @@ def _find_absolute_paths(
132137
"""
133138
if isinstance(paths, str):
134139
if is_remote_uri(paths) and kwargs.get("engine") == "zarr":
135-
try:
136-
from fsspec.core import get_fs_token_paths
137-
except ImportError as e:
138-
raise ImportError(
139-
"The use of remote URLs for opening zarr requires the package fsspec"
140-
) from e
141-
142-
fs, _, _ = get_fs_token_paths(
140+
if TYPE_CHECKING:
141+
import fsspec
142+
else:
143+
fsspec = attempt_import("fsspec")
144+
145+
fs, _, _ = fsspec.core.get_fs_token_paths(
143146
paths,
144147
mode="rb",
145148
storage_options=kwargs.get("backend_kwargs", {}).get(

xarray/backends/zarr.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from xarray.core.utils import (
2828
FrozenDict,
2929
HiddenKeyDict,
30+
attempt_import,
3031
close_on_error,
3132
emit_user_level_warning,
3233
)
@@ -865,7 +866,10 @@ def store(
865866
dimension on which the zarray will be appended
866867
only needed in append mode
867868
"""
868-
import zarr
869+
if TYPE_CHECKING:
870+
import zarr
871+
else:
872+
zarr = attempt_import("zarr")
869873

870874
existing_keys = tuple(self.zarr_group.array_keys())
871875

@@ -1638,7 +1642,10 @@ def _get_open_params(
16381642
use_zarr_fill_value_as_mask,
16391643
zarr_format,
16401644
):
1641-
import zarr
1645+
if TYPE_CHECKING:
1646+
import zarr
1647+
else:
1648+
zarr = attempt_import("zarr")
16421649

16431650
# zarr doesn't support pathlib.Path objects yet. zarr-python#601
16441651
if isinstance(store, os.PathLike):

xarray/coding/cftime_offsets.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,7 @@
6767
nanosecond_precision_timestamp,
6868
no_default,
6969
)
70-
from xarray.core.utils import emit_user_level_warning
71-
72-
try:
73-
import cftime
74-
except ImportError:
75-
cftime = None
76-
70+
from xarray.core.utils import attempt_import, emit_user_level_warning
7771

7872
if TYPE_CHECKING:
7973
from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias
@@ -93,24 +87,26 @@ def _nanosecond_precision_timestamp(*args, **kwargs):
9387

9488
def get_date_type(calendar, use_cftime=True):
9589
"""Return the cftime date type for a given calendar name."""
96-
if cftime is None:
97-
raise ImportError("cftime is required for dates with non-standard calendars")
90+
if TYPE_CHECKING:
91+
import cftime
9892
else:
99-
if _is_standard_calendar(calendar) and not use_cftime:
100-
return _nanosecond_precision_timestamp
101-
102-
calendars = {
103-
"noleap": cftime.DatetimeNoLeap,
104-
"360_day": cftime.Datetime360Day,
105-
"365_day": cftime.DatetimeNoLeap,
106-
"366_day": cftime.DatetimeAllLeap,
107-
"gregorian": cftime.DatetimeGregorian,
108-
"proleptic_gregorian": cftime.DatetimeProlepticGregorian,
109-
"julian": cftime.DatetimeJulian,
110-
"all_leap": cftime.DatetimeAllLeap,
111-
"standard": cftime.DatetimeGregorian,
112-
}
113-
return calendars[calendar]
93+
cftime = attempt_import("cftime")
94+
95+
if _is_standard_calendar(calendar) and not use_cftime:
96+
return _nanosecond_precision_timestamp
97+
98+
calendars = {
99+
"noleap": cftime.DatetimeNoLeap,
100+
"360_day": cftime.Datetime360Day,
101+
"365_day": cftime.DatetimeNoLeap,
102+
"366_day": cftime.DatetimeAllLeap,
103+
"gregorian": cftime.DatetimeGregorian,
104+
"proleptic_gregorian": cftime.DatetimeProlepticGregorian,
105+
"julian": cftime.DatetimeJulian,
106+
"all_leap": cftime.DatetimeAllLeap,
107+
"standard": cftime.DatetimeGregorian,
108+
}
109+
return calendars[calendar]
114110

115111

116112
class BaseCFTimeOffset:
@@ -141,8 +137,10 @@ def __add__(self, other):
141137
return self.__apply__(other)
142138

143139
def __sub__(self, other):
144-
if cftime is None:
145-
raise ModuleNotFoundError("No module named 'cftime'")
140+
if TYPE_CHECKING:
141+
import cftime
142+
else:
143+
cftime = attempt_import("cftime")
146144

147145
if isinstance(other, cftime.datetime):
148146
raise TypeError("Cannot subtract a cftime.datetime from a time offset.")
@@ -293,8 +291,7 @@ def _adjust_n_years(other, n, month, reference_day):
293291

294292
def _shift_month(date, months, day_option: DayOption = "start"):
295293
"""Shift the date to a month start or end a given number of months away."""
296-
if cftime is None:
297-
raise ModuleNotFoundError("No module named 'cftime'")
294+
_ = attempt_import("cftime")
298295

299296
has_year_zero = date.has_year_zero
300297
delta_year = (date.month + months) // 12
@@ -458,8 +455,10 @@ def onOffset(self, date) -> bool:
458455
return mod_month == 0 and date.day == self._get_offset_day(date)
459456

460457
def __sub__(self, other: Self) -> Self:
461-
if cftime is None:
462-
raise ModuleNotFoundError("No module named 'cftime'")
458+
if TYPE_CHECKING:
459+
import cftime
460+
else:
461+
cftime = attempt_import("cftime")
463462

464463
if isinstance(other, cftime.datetime):
465464
raise TypeError("Cannot subtract cftime.datetime from offset.")
@@ -544,8 +543,10 @@ def __apply__(self, other):
544543
return _shift_month(other, months, self._day_option)
545544

546545
def __sub__(self, other):
547-
if cftime is None:
548-
raise ModuleNotFoundError("No module named 'cftime'")
546+
if TYPE_CHECKING:
547+
import cftime
548+
else:
549+
cftime = attempt_import("cftime")
549550

550551
if isinstance(other, cftime.datetime):
551552
raise TypeError("Cannot subtract cftime.datetime from offset.")
@@ -828,8 +829,10 @@ def delta_to_tick(delta: timedelta | pd.Timedelta) -> Tick:
828829

829830

830831
def to_cftime_datetime(date_str_or_date, calendar=None):
831-
if cftime is None:
832-
raise ModuleNotFoundError("No module named 'cftime'")
832+
if TYPE_CHECKING:
833+
import cftime
834+
else:
835+
cftime = attempt_import("cftime")
833836

834837
if isinstance(date_str_or_date, str):
835838
if calendar is None:
@@ -867,8 +870,10 @@ def _maybe_normalize_date(date, normalize):
867870
def _generate_linear_range(start, end, periods):
868871
"""Generate an equally-spaced sequence of cftime.datetime objects between
869872
and including two dates (whose length equals the number of periods)."""
870-
if cftime is None:
871-
raise ModuleNotFoundError("No module named 'cftime'")
873+
if TYPE_CHECKING:
874+
import cftime
875+
else:
876+
cftime = attempt_import("cftime")
872877

873878
total_seconds = (end - start).total_seconds()
874879
values = np.linspace(0.0, total_seconds, periods, endpoint=True)

xarray/coding/cftimeindex.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,7 @@
5858
)
5959
from xarray.core.common import _contains_cftime_datetimes
6060
from xarray.core.options import OPTIONS
61-
from xarray.core.utils import is_scalar
62-
63-
try:
64-
import cftime
65-
except ImportError:
66-
cftime = None
61+
from xarray.core.utils import attempt_import, is_scalar
6762

6863
if TYPE_CHECKING:
6964
from xarray.coding.cftime_offsets import BaseCFTimeOffset
@@ -130,8 +125,7 @@ def parse_iso8601_like(datetime_string):
130125

131126

132127
def _parse_iso8601_with_reso(date_type, timestr):
133-
if cftime is None:
134-
raise ModuleNotFoundError("No module named 'cftime'")
128+
_ = attempt_import("cftime")
135129

136130
default = date_type(1, 1, 1)
137131
result = parse_iso8601_like(timestr)
@@ -200,8 +194,10 @@ def _field_accessor(name, docstring=None, min_cftime_version="0.0"):
200194
"""Adapted from pandas.tseries.index._field_accessor"""
201195

202196
def f(self, min_cftime_version=min_cftime_version):
203-
if cftime is None:
204-
raise ModuleNotFoundError("No module named 'cftime'")
197+
if TYPE_CHECKING:
198+
import cftime
199+
else:
200+
cftime = attempt_import("cftime")
205201

206202
if Version(cftime.__version__) >= Version(min_cftime_version):
207203
return get_date_field(self._data, name)
@@ -225,8 +221,10 @@ def get_date_type(self):
225221

226222

227223
def assert_all_valid_date_type(data):
228-
if cftime is None:
229-
raise ModuleNotFoundError("No module named 'cftime'")
224+
if TYPE_CHECKING:
225+
import cftime
226+
else:
227+
cftime = attempt_import("cftime")
230228

231229
if len(data) > 0:
232230
sample = data[0]
@@ -803,6 +801,10 @@ def round(self, freq):
803801

804802
@property
805803
def is_leap_year(self):
804+
if TYPE_CHECKING:
805+
import cftime
806+
else:
807+
cftime = attempt_import("cftime")
806808
func = np.vectorize(cftime.is_leap_year)
807809
return func(self.year, calendar=self.calendar)
808810

xarray/coding/times.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable, Hashable
66
from datetime import datetime, timedelta
77
from functools import partial
8-
from typing import Literal, Union, cast
8+
from typing import TYPE_CHECKING, Literal, Union, cast
99

1010
import numpy as np
1111
import pandas as pd
@@ -25,7 +25,7 @@
2525
from xarray.core.duck_array_ops import asarray, ravel, reshape
2626
from xarray.core.formatting import first_n_items, format_timestamp, last_item
2727
from xarray.core.pdcompat import nanosecond_precision_timestamp
28-
from xarray.core.utils import emit_user_level_warning
28+
from xarray.core.utils import attempt_import, emit_user_level_warning
2929
from xarray.core.variable import Variable
3030
from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type
3131
from xarray.namedarray.pycompat import is_chunked_array
@@ -235,8 +235,10 @@ def _decode_cf_datetime_dtype(
235235
def _decode_datetime_with_cftime(
236236
num_dates: np.ndarray, units: str, calendar: str
237237
) -> np.ndarray:
238-
if cftime is None:
239-
raise ModuleNotFoundError("No module named 'cftime'")
238+
if TYPE_CHECKING:
239+
import cftime
240+
else:
241+
cftime = attempt_import("cftime")
240242
if num_dates.size > 0:
241243
return np.asarray(
242244
cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True)
@@ -634,8 +636,10 @@ def _encode_datetime_with_cftime(dates, units: str, calendar: str) -> np.ndarray
634636
This method is more flexible than xarray's parsing using datetime64[ns]
635637
arrays but also slower because it loops over each element.
636638
"""
637-
if cftime is None:
638-
raise ModuleNotFoundError("No module named 'cftime'")
639+
if TYPE_CHECKING:
640+
import cftime
641+
else:
642+
cftime = attempt_import("cftime")
639643

640644
if np.issubdtype(dates.dtype, np.datetime64):
641645
# numpy's broken datetime conversion only works for us precision

xarray/core/utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import contextlib
4040
import functools
41+
import importlib
4142
import inspect
4243
import io
4344
import itertools
@@ -64,7 +65,7 @@
6465
)
6566
from enum import Enum
6667
from pathlib import Path
67-
from types import EllipsisType
68+
from types import EllipsisType, ModuleType
6869
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, TypeVar, overload
6970

7071
import numpy as np
@@ -1194,6 +1195,60 @@ def _resolve_doubly_passed_kwarg(
11941195
return kwargs_dict
11951196

11961197

1198+
def attempt_import(module: str) -> ModuleType:
1199+
"""Import an optional dependency, and raise an informative error on failure.
1200+
1201+
Parameters
1202+
----------
1203+
module : str
1204+
Module to import. For example, ``'zarr'`` or ``'matplotlib.pyplot'``.
1205+
1206+
Returns
1207+
-------
1208+
module : ModuleType
1209+
The Imported module.
1210+
1211+
Raises
1212+
------
1213+
ImportError
1214+
If the module could not be imported.
1215+
1216+
Notes
1217+
-----
1218+
Static type checkers will not be able to infer the type of the returned module,
1219+
so it is recommended to precede this function with a direct import of the module,
1220+
guarded by an ``if TYPE_CHECKING`` block, to preserve type checker functionality.
1221+
See the examples section below for a demonstration.
1222+
1223+
Examples
1224+
--------
1225+
>>> from xarray.core.utils import attempt_import
1226+
>>> if TYPE_CHECKING:
1227+
... import zarr
1228+
... else:
1229+
... zarr = attempt_import("zarr")
1230+
...
1231+
"""
1232+
install_mapping = dict(nc_time_axis="nc-time-axis")
1233+
package_purpose = dict(
1234+
zarr="for working with Zarr stores",
1235+
cftime="for working with non-standard calendars",
1236+
matplotlib="for plotting",
1237+
hypothesis="for the `xarray.testing.strategies` submodule",
1238+
)
1239+
package_name = module.split(".")[0] # e.g. "zarr" from "zarr.storage"
1240+
install_name = install_mapping.get(package_name, package_name)
1241+
reason = package_purpose.get(package_name, "")
1242+
try:
1243+
return importlib.import_module(module)
1244+
except (ImportError, ModuleNotFoundError) as e:
1245+
raise ImportError(
1246+
f"The {install_name} package is required {reason}"
1247+
" but could not be imported."
1248+
" Please install it with your package manager (e.g. conda or pip)."
1249+
) from e
1250+
1251+
11971252
_DEFAULT_NAME = ReprObject("<default-name>")
11981253

11991254

0 commit comments

Comments
 (0)