Skip to content

Commit d114196

Browse files
[dask] Workarounds for different Dask versions. (#11436)
--------- Co-authored-by: TomAugspurger <toaugspurger@nvidia.com>
1 parent f41be2c commit d114196

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

python-package/xgboost/dask/__init__.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
import logging
5656
from collections import defaultdict
5757
from contextlib import contextmanager
58-
from functools import partial, update_wrapper
58+
from functools import cache, partial, update_wrapper
5959
from threading import Thread
6060
from typing import (
6161
Any,
@@ -85,6 +85,8 @@
8585
from dask import dataframe as dd
8686
from dask.delayed import Delayed
8787
from distributed import Future
88+
from packaging.version import Version
89+
from packaging.version import parse as parse_version
8890

8991
from .. import collective, config
9092
from .._typing import FeatureNames, FeatureTypes, IterationRange
@@ -171,6 +173,21 @@
171173
LOGGER = logging.getLogger("[xgboost.dask]")
172174

173175

176+
@cache
def _DASK_VERSION() -> Version:
    """Return the installed Dask version, parsed from ``dask.__version__``.

    The result is memoised with ``functools.cache``, so the version string
    is parsed at most once per process.
    """
    installed = parse_version(dask.__version__)
    return installed
179+
180+
181+
@cache
def _DASK_2024_12_1() -> bool:
    """Return True when the installed Dask version is at least 2024.12.1.

    The answer is memoised with ``functools.cache``.
    """
    installed = _DASK_VERSION()
    threshold = parse_version("2024.12.1")
    return installed >= threshold
184+
185+
186+
@cache
def _DASK_2025_3_0() -> bool:
    """Return True when the installed Dask version is at least 2025.3.0.

    The answer is memoised with ``functools.cache``.
    """
    installed = _DASK_VERSION()
    threshold = parse_version("2025.3.0")
    return installed >= threshold
189+
190+
174191
def _try_start_tracker(
175192
n_workers: int,
176193
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
@@ -1491,6 +1508,18 @@ async def _predict_async(
14911508
)
14921509
if isinstance(predts, dd.DataFrame):
14931510
predts = predts.to_dask_array()
1511+
# Make sure the booster is part of the task graph implicitly
1512+
# This is only needed for certain versions of dask.
1513+
if _DASK_2024_12_1() and not _DASK_2025_3_0():
1514+
# Fixes this issue for dask>=2024.12.1,<2025.3.0
1515+
# Dask==2025.3.0 fails with:
1516+
# RuntimeError: Attempting to use an asynchronous
1517+
# Client in a synchronous context of `dask.compute`
1518+
#
1519+
# Dask==2025.4.0 fails with:
1520+
# TypeError: Value type is not supported for data
1521+
# iterator:<class 'distributed.client.Future'>
1522+
predts = predts.persist()
14941523
return predts
14951524

14961525
@_deprecate_positional_args

0 commit comments

Comments
 (0)