|
55 | 55 | import logging
|
56 | 56 | from collections import defaultdict
|
57 | 57 | from contextlib import contextmanager
|
58 |
| -from functools import partial, update_wrapper |
| 58 | +from functools import cache, partial, update_wrapper |
59 | 59 | from threading import Thread
|
60 | 60 | from typing import (
|
61 | 61 | Any,
|
|
85 | 85 | from dask import dataframe as dd
|
86 | 86 | from dask.delayed import Delayed
|
87 | 87 | from distributed import Future
|
| 88 | +from packaging.version import Version |
| 89 | +from packaging.version import parse as parse_version |
88 | 90 |
|
89 | 91 | from .. import collective, config
|
90 | 92 | from .._typing import FeatureNames, FeatureTypes, IterationRange
|
|
171 | 173 | LOGGER = logging.getLogger("[xgboost.dask]")
|
172 | 174 |
|
173 | 175 |
|
| 176 | +@cache |
| 177 | +def _DASK_VERSION() -> Version: |
| 178 | + return parse_version(dask.__version__) |
| 179 | + |
| 180 | + |
| 181 | +@cache |
| 182 | +def _DASK_2024_12_1() -> bool: |
| 183 | + return _DASK_VERSION() >= parse_version("2024.12.1") |
| 184 | + |
| 185 | + |
| 186 | +@cache |
| 187 | +def _DASK_2025_3_0() -> bool: |
| 188 | + return _DASK_VERSION() >= parse_version("2025.3.0") |
| 189 | + |
| 190 | + |
174 | 191 | def _try_start_tracker(
|
175 | 192 | n_workers: int,
|
176 | 193 | addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
@@ -1491,6 +1508,18 @@ async def _predict_async(
|
1491 | 1508 | )
|
1492 | 1509 | if isinstance(predts, dd.DataFrame):
|
1493 | 1510 | predts = predts.to_dask_array()
|
| 1511 | + # Make sure the booster is part of the task graph implicitly |
| 1512 | + # only needed for certain versions of dask. |
| 1513 | + if _DASK_2024_12_1() and not _DASK_2025_3_0(): |
| 1514 | + # Fixes this issue for dask>=2024.1.1,<2025.3.0 |
| 1515 | + # Dask==2025.3.0 fails with: |
| 1516 | + # RuntimeError: Attempting to use an asynchronous |
| 1517 | + # Client in a synchronous context of `dask.compute` |
| 1518 | + # |
| 1519 | + # Dask==2025.4.0 fails with: |
| 1520 | + # TypeError: Value type is not supported for data |
| 1521 | + # iterator:<class 'distributed.client.Future'> |
| 1522 | + predts = predts.persist() |
1494 | 1523 | return predts
|
1495 | 1524 |
|
1496 | 1525 | @_deprecate_positional_args
|
|
0 commit comments