Skip to content

Commit ff56568

Browse files
authored
Dask 2025.4.0 scheduler info compatibility (#11462)
1 parent 9ad4e24 commit ff56568

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

python-package/xgboost/dask/__init__.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,11 @@ def __init__(self, **args: CollArgsVals) -> None:
254254
super().__init__(**args)
255255

256256
worker = distributed.get_worker()
257-
with distributed.worker_client() as client:
258-
info = client.scheduler_info()
259-
w = info["workers"][worker.address]
260-
wid = w["id"]
261257
# We use task ID for rank assignment which makes the RABIT rank consistent (but
262258
# not the same, since the task ID is a string and "10" is sorted before "2") with dask
263-
# worker ID. This outsources the rank assignment to dask and prevents
259+
# worker name. This outsources the rank assignment to dask and prevents
264260
# non-deterministic issue.
265-
self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address)
261+
self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{worker.name}]:{worker.address}"
266262

267263

268264
def _get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
@@ -923,12 +919,11 @@ def train( # pylint: disable=unused-argument
923919
924920
"""
925921
client = _get_client(client)
926-
args = locals()
927922
return client.sync(
928923
_train_async,
929924
global_config=config.get_config(),
930925
dconfig=_get_dask_config(),
931-
**args,
926+
**locals(),
932927
)
933928

934929

python-package/xgboost/testing/dask.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dask import array as da
88
from dask import dataframe as dd
99
from distributed import Client, get_worker
10+
from packaging.version import parse as parse_version
1011
from sklearn.datasets import make_classification
1112

1213
import xgboost as xgb
@@ -15,7 +16,7 @@
1516
from xgboost.testing.updater import get_basescore
1617

1718
from .. import dask as dxgb
18-
from ..dask import _get_rabit_args
19+
from ..dask import _DASK_VERSION, _get_rabit_args
1920
from .data import make_batches
2021
from .data import make_categorical as make_cat_local
2122

@@ -179,7 +180,8 @@ def get_rabit_args(client: Client, n_workers: int) -> Any:
179180

180181
def get_client_workers(client: Client) -> List[str]:
181182
"Get workers from a dask client."
182-
workers = client.scheduler_info()["workers"]
183+
kwargs = {"n_workers": -1} if _DASK_VERSION() >= parse_version("2025.4.0") else {}
184+
workers = client.scheduler_info(**kwargs)["workers"]
183185
return list(workers.keys())
184186

185187

0 commit comments

Comments
 (0)