File tree 2 files changed +7
-10
lines changed
2 files changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -254,15 +254,11 @@ def __init__(self, **args: CollArgsVals) -> None:
254
254
super ().__init__ (** args )
255
255
256
256
worker = distributed .get_worker ()
257
- with distributed .worker_client () as client :
258
- info = client .scheduler_info ()
259
- w = info ["workers" ][worker .address ]
260
- wid = w ["id" ]
261
257
# We use task ID for rank assignment which makes the RABIT rank consistent (but
262
258
# not the same as task ID is string and "10" is sorted before "2") with dask
263
- # worker ID . This outsources the rank assignment to dask and prevents
259
+ # worker name . This outsources the rank assignment to dask and prevents
264
260
# non-deterministic issue.
265
- self .args ["DMLC_TASK_ID" ] = f"[xgboost.dask-{ wid } ]:" + str ( worker .address )
261
+ self .args ["DMLC_TASK_ID" ] = f"[xgboost.dask-{ worker . name } ]:{ worker .address } "
266
262
267
263
268
264
def _get_client (client : Optional ["distributed.Client" ]) -> "distributed.Client" :
@@ -923,12 +919,11 @@ def train( # pylint: disable=unused-argument
923
919
924
920
"""
925
921
client = _get_client (client )
926
- args = locals ()
927
922
return client .sync (
928
923
_train_async ,
929
924
global_config = config .get_config (),
930
925
dconfig = _get_dask_config (),
931
- ** args ,
926
+ ** locals () ,
932
927
)
933
928
934
929
Original file line number Diff line number Diff line change 7
7
from dask import array as da
8
8
from dask import dataframe as dd
9
9
from distributed import Client , get_worker
10
+ from packaging .version import parse as parse_version
10
11
from sklearn .datasets import make_classification
11
12
12
13
import xgboost as xgb
15
16
from xgboost .testing .updater import get_basescore
16
17
17
18
from .. import dask as dxgb
18
- from ..dask import _get_rabit_args
19
+ from ..dask import _DASK_VERSION , _get_rabit_args
19
20
from .data import make_batches
20
21
from .data import make_categorical as make_cat_local
21
22
@@ -179,7 +180,8 @@ def get_rabit_args(client: Client, n_workers: int) -> Any:
179
180
180
181
def get_client_workers (client : Client ) -> List [str ]:
181
182
"Get workers from a dask client."
182
- workers = client .scheduler_info ()["workers" ]
183
+ kwargs = {"n_workers" : - 1 } if _DASK_VERSION () >= parse_version ("2025.4.0" ) else {}
184
+ workers = client .scheduler_info (** kwargs )["workers" ]
183
185
return list (workers .keys ())
184
186
185
187
You can’t perform that action at this time.
0 commit comments